diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index e4e6767a444ead9..81479d9812042d6 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -5,7 +5,9 @@ ("xenial", [ ("rocm", [ ("3.5.1", [ - X("3.6"), + ("3.6", [ + ('build_only', [XImportant(True)]), + ]), ]), ]), ("gcc", [ diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index 3201d4581c9a96b..0df2c8ed920b6b3 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -307,7 +307,7 @@ def instantiate_configs(): parallel_backend = fc.find_prop("parallel_backend") or None build_only = fc.find_prop("build_only") or False is_coverage = fc.find_prop("is_coverage") or False - if build_only and restrict_phases is None: + if build_only: restrict_phases = ["build"] if is_coverage and restrict_phases is None: restrict_phases = ["build", "coverage_test"] diff --git a/.circleci/cimodel/data/simple/mobile_definitions.py b/.circleci/cimodel/data/simple/mobile_definitions.py index af48010937d84ee..9c326b779eba230 100644 --- a/.circleci/cimodel/data/simple/mobile_definitions.py +++ b/.circleci/cimodel/data/simple/mobile_definitions.py @@ -57,11 +57,6 @@ def gen_tree(self): [DOCKER_REQUIREMENT_ASAN], ["build"] ), - MobileJob( - DOCKER_IMAGE_ASAN, - [DOCKER_REQUIREMENT_ASAN], - ["custom", "build", "static"] - ), # Use LLVM-DEV toolchain in android-ndk-r19c docker image MobileJob( diff --git a/.circleci/config.yml b/.circleci/config.yml index 79176e5e19c9421..b9be825f5d06797 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -6120,54 +6120,9 @@ workflows: name: pytorch_linux_xenial_rocm3_5_1_py3_6_build requires: - "docker-pytorch-linux-xenial-rocm3.5.1-py3.6" - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-build" docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6" resource_class: xlarge - - pytorch_linux_test: - name: pytorch_linux_xenial_rocm3_5_1_py3_6_test1 - requires: - - pytorch_linux_xenial_rocm3_5_1_py3_6_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-test1" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6" - resource_class: pytorch/amd-gpu - - pytorch_linux_test: - name: pytorch_linux_xenial_rocm3_5_1_py3_6_test2 - requires: - - pytorch_linux_xenial_rocm3_5_1_py3_6_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-test2" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6" - resource_class: pytorch/amd-gpu - - pytorch_linux_test: - name: pytorch_linux_xenial_rocm3_5_1_py3_6_caffe2_test - requires: - - pytorch_linux_xenial_rocm3_5_1_py3_6_build - filters: - branches: - only: - - master - - /ci-all\/.*/ - - /release\/.*/ - build_environment: "pytorch-linux-xenial-rocm3.5.1-py3.6-caffe2_test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-rocm3.5.1-py3.6" - resource_class: pytorch/amd-gpu - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc5_4_build requires: @@ -6652,13 +6607,6 @@ workflows: name: pytorch_linux_xenial_py3_clang5_mobile_build requires: - docker-pytorch-linux-xenial-py3-clang5-asan - - pytorch_linux_build: - build_environment: pytorch-linux-xenial-py3-clang5-mobile-custom-build-static - build_only: "1" - docker_image: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan - name: pytorch_linux_xenial_py3_clang5_mobile_custom_build_static - requires: - - docker-pytorch-linux-xenial-py3-clang5-asan - pytorch_linux_build: build_environment: pytorch-linux-xenial-py3-clang5-mobile-custom-build-dynamic build_only: "1" diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index d310a125a98ec95..c2057901b2d7bda 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -160,7 +160,6 @@ case "$image" in KATEX=yes ;; pytorch-linux-xenial-cuda11.0-cudnn8-py3-gcc7) - UBUNTU_VERSION=16.04-rc CUDA_VERSION=11.0 CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.6 @@ -230,7 +229,6 @@ case "$image" in VISION=yes ;; pytorch-linux-bionic-cuda11.0-cudnn8-py3.6-gcc9) - UBUNTU_VERSION=18.04-rc CUDA_VERSION=11.0 CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.6 @@ -241,7 +239,6 @@ case "$image" in KATEX=yes ;; pytorch-linux-bionic-cuda11.0-cudnn8-py3.8-gcc9) - UBUNTU_VERSION=18.04-rc CUDA_VERSION=11.0 CUDNN_VERSION=8 ANACONDA_PYTHON_VERSION=3.8 diff --git a/.circleci/docker/common/install_conda.sh b/.circleci/docker/common/install_conda.sh index b25ab843006594e..2a1c2bd0ea8fad1 100755 --- a/.circleci/docker/common/install_conda.sh +++ b/.circleci/docker/common/install_conda.sh @@ -86,6 +86,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then conda_install magma-cuda101 -c pytorch elif [[ "$CUDA_VERSION" == 10.2* ]]; then conda_install magma-cuda102 -c pytorch + elif [[ "$CUDA_VERSION" == 11.0* ]]; then + conda_install magma-cuda110 -c pytorch fi # TODO: This isn't working atm diff --git a/.circleci/scripts/binary_upload.sh b/.circleci/scripts/binary_upload.sh index 139f5e590c8ea8b..9282674d3a47584 100755 --- a/.circleci/scripts/binary_upload.sh +++ b/.circleci/scripts/binary_upload.sh @@ -30,7 +30,7 @@ do_backup() { ( pushd /tmp/workspace set -x - ${AWS_S3_CP} --recursive . "${BACKUP_BUCKET}/${CIRCLE_TAG}/${backup_dir}" + ${AWS_S3_CP} --recursive . "${BACKUP_BUCKET}/${CIRCLE_TAG}/${backup_dir}/" ) } @@ -52,7 +52,7 @@ s3_upload() { local pkg_type extension="$1" pkg_type="$2" - s3_dir="${UPLOAD_BUCKET}/${pkg_type}/${UPLOAD_CHANNEL}/${UPLOAD_SUBFOLDER}" + s3_dir="${UPLOAD_BUCKET}/${pkg_type}/${UPLOAD_CHANNEL}/${UPLOAD_SUBFOLDER}/" ( for pkg in ${PKG_DIR}/*.${extension}; do ( diff --git a/.jenkins/pytorch/build-mobile.sh b/.jenkins/pytorch/build-mobile.sh index 03b836b76c3f67f..b1234f2728130e5 100755 --- a/.jenkins/pytorch/build-mobile.sh +++ b/.jenkins/pytorch/build-mobile.sh @@ -22,9 +22,7 @@ retry pip install --pre torch torchvision \ # Run end-to-end process of building mobile library, linking into the predictor # binary, and running forward pass with a real model. -if [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-static* ]]; then - TEST_CUSTOM_BUILD_STATIC=1 test/mobile/custom_build/build.sh -elif [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-dynamic* ]]; then +if [[ "$BUILD_ENVIRONMENT" == *-mobile-custom-build-dynamic* ]]; then export LLVM_DIR="$(llvm-config-5.0 --prefix)" echo "LLVM_DIR: ${LLVM_DIR}" TEST_CUSTOM_BUILD_DYNAMIC=1 test/mobile/custom_build/build.sh diff --git a/BUILD.bazel b/BUILD.bazel index da50ea2d4b8c424..f7be71ec624d1d2 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -2,7 +2,7 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") load("@rules_proto//proto:defs.bzl", "proto_library") load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_proto_library", "cc_test") -load("//third_party:substitution.bzl", "template_rule") +load("//third_party:substitution.bzl", "header_template_rule") load("//:tools/build_variables.bzl", "torch_cpp_srcs", "libtorch_python_core_sources", "libtorch_core_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "jit_core_sources") load("//tools/rules:cu.bzl", "cu_library") load("//tools/config:defs.bzl", "if_cuda") @@ -27,19 +27,18 @@ COMMON_COPTS = [ ]) # c10 -template_rule( +header_template_rule( name = "cmake_macros_h", src = "c10/macros/cmake_macros.h.in", out = "c10/macros/cmake_macros.h", substitutions = { "cmakedefine": "define", "#define FEATURE_TORCH_MOBILE": "/* #undef FEATURE_TORCH_MOBILE */", - "#define USE_STATIC_DISPATCH": "/* #undef USE_STATIC_DISPATCH */", "#define C10_USE_NUMA": "/* #undef C10_USE_NUMA */", }, ) -template_rule( +header_template_rule( name = "cuda_cmake_macros_h", src = "c10/cuda/impl/cuda_cmake_macros.h.in", out = "c10/cuda/impl/cuda_cmake_macros.h", @@ -58,13 +57,12 @@ cc_library( "c10/macros/*.h", "c10/util/*.h", "c10/util/*.hpp", - ]) + [ - "c10/macros/cmake_macros.h", - "c10/cuda/impl/cuda_cmake_macros.h", - ], + ]), deps = [ "@com_github_gflags_gflags//:gflags", "@com_github_glog//:glog", + ":cmake_macros_h", + ":cuda_cmake_macros_h", ], ) @@ -531,7 +529,7 @@ filegroup( ], ) -template_rule( +header_template_rule( name = "aten_src_ATen_config", src = "aten/src/ATen/Config.h.in", out = "aten/src/ATen/Config.h", @@ -547,7 +545,7 @@ template_rule( }, ) -template_rule( +header_template_rule( name = "aten_src_ATen_cuda_config", src = "aten/src/ATen/cuda/CUDAConfig.h.in", out = "aten/src/ATen/cuda/CUDAConfig.h", @@ -558,7 +556,7 @@ template_rule( }, ) -template_rule( +header_template_rule( name = "aten_src_TH_THGeneral", src = "aten/src/TH/THGeneral.h.in", out = "aten/src/TH/THGeneral.h", @@ -570,7 +568,7 @@ template_rule( }, ) -template_rule( +header_template_rule( name = "aten_src_THC_THCGeneral", src = "aten/src/THC/THCGeneral.h.in", out = "aten/src/THC/THCGeneral.h", @@ -582,8 +580,6 @@ template_rule( cc_library( name = "aten_headers", hdrs = [ - "aten/src/TH/THGeneral.h", - "aten/src/THC/THCGeneral.h", "torch/csrc/WindowsTorchApiMacro.h", "torch/csrc/jit/frontend/function_schema_parser.h", ] + glob([ @@ -605,6 +601,8 @@ cc_library( ], deps = [ ":c10_headers", + ":aten_src_TH_THGeneral", + ":aten_src_THC_THCGeneral", ], ) @@ -766,7 +764,7 @@ cc_proto_library( deps = [":caffe2_proto_source"], ) -template_rule( +header_template_rule( name = "caffe2_core_macros_h", src = "caffe2/core/macros.h.in", out = "caffe2/core/macros.h", @@ -1586,7 +1584,6 @@ filegroup( cc_library( name = "caffe2_for_aten_headers", hdrs = [ - "caffe2/core/macros.h", "caffe2/core/common.h", "caffe2/core/logging.h", "caffe2/core/types.h", @@ -1604,6 +1601,7 @@ cc_library( deps = [ ":c10_headers", ":caffe2_protos", + ":caffe2_core_macros_h", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index a842fe3ba9ee7a9..fb7b306547bbeb9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,7 +124,6 @@ option(BUILD_PYTHON "Build Python binaries" ON) option(BUILD_CAFFE2_OPS "Build Caffe2 operators" ON) option(BUILD_SHARED_LIBS "Build libcaffe2.so" ON) option(BUILD_CAFFE2_MOBILE "Build libcaffe2 for mobile (deprecating)" OFF) -option(USE_STATIC_DISPATCH "Use static dispatch for ATen operators" OFF) cmake_dependent_option( CAFFE2_LINK_LOCAL_PROTOBUF "If set, build protobuf inside libcaffe2.so." ON "BUILD_SHARED_LIBS AND BUILD_CUSTOM_PROTOBUF" OFF) diff --git a/README.md b/README.md index ce9d81c3d07caf0..6e1fcfdb828f833 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ PyTorch is a Python package that provides two high-level features: - Tensor computation (like NumPy) with strong GPU acceleration - Deep neural networks built on a tape-based autograd system -You can reuse your favorite Python packages such as NumPy, SciPy and Cython to extend PyTorch when needed. +You can reuse your favorite Python packages such as NumPy, SciPy, and Cython to extend PyTorch when needed. - [More about PyTorch](#more-about-pytorch) - [Installation](#installation) @@ -45,7 +45,7 @@ At a granular level, PyTorch is a library that consists of the following compone | [**torch.multiprocessing**](https://pytorch.org/docs/stable/multiprocessing.html) | Python multiprocessing, but with magical memory sharing of torch Tensors across processes. Useful for data loading and Hogwild training | | [**torch.utils**](https://pytorch.org/docs/stable/data.html) | DataLoader and other utility functions for convenience | -Usually PyTorch is used either as: +Usually, PyTorch is used either as: - a replacement for NumPy to use the power of GPUs. - a deep learning research platform that provides maximum flexibility and speed. @@ -58,7 +58,7 @@ If you use NumPy, then you have used Tensors (a.k.a. ndarray). ![Tensor illustration](./docs/source/_static/img/tensor_illustration.png) -PyTorch provides Tensors that can live either on the CPU or the GPU, and accelerates the +PyTorch provides Tensors that can live either on the CPU or the GPU and accelerates the computation by a huge amount. We provide a wide variety of tensor routines to accelerate and fit your scientific computation needs @@ -69,8 +69,8 @@ And they are fast! PyTorch has a unique way of building neural networks: using and replaying a tape recorder. -Most frameworks such as TensorFlow, Theano, Caffe and CNTK have a static view of the world. -One has to build a neural network, and reuse the same structure again and again. +Most frameworks such as TensorFlow, Theano, Caffe, and CNTK have a static view of the world. +One has to build a neural network and reuse the same structure again and again. Changing the way the network behaves means that one has to start from scratch. With PyTorch, we use a technique called reverse-mode auto-differentiation, which allows you to @@ -96,9 +96,9 @@ Our goal is to not reinvent the wheel where appropriate. ### Imperative Experiences -PyTorch is designed to be intuitive, linear in thought and easy to use. +PyTorch is designed to be intuitive, linear in thought, and easy to use. When you execute a line of code, it gets executed. There isn't an asynchronous view of the world. -When you drop into a debugger, or receive error messages and stack traces, understanding them is straightforward. +When you drop into a debugger or receive error messages and stack traces, understanding them is straightforward. The stack trace points to exactly where your code was defined. We hope you never spend hours debugging your code because of bad stack traces or asynchronous and opaque execution engines. @@ -220,10 +220,10 @@ If the version of Visual Studio 2017 is higher than 15.6, installing of "VC++ 20 NVTX is a part of CUDA distributive, where it is called "Nsight Compute". To install it onto already installed CUDA run CUDA installation once again and check the corresponding checkbox. Be sure that CUDA with Nsight Compute is installed after Visual Studio 2017. -Currently VS 2017, VS 2019 and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise it will use VS 2017. +Currently, VS 2017, VS 2019, and Ninja are supported as the generator of CMake. If `ninja.exe` is detected in `PATH`, then Ninja will be used as the default generator, otherwise, it will use VS 2017.
If Ninja is selected as the generator, the latest MSVC which is newer than VS 2015 (14.0) will get selected as the underlying toolchain. If you use CMake <= 3.14.2 and has VS 2019 installed, then even if you specify VS 2017 as the generator, VS 2019 will get selected as the generator. -CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like `nvcc fatal : Host compiler targets unsupported OS`. For this kind of problem, please install the corresponding VS toolchain in the table below and then you can either specify the toolset during activation (recommended) or set `CUDAHOSTCXX` to override the cuda host compiler (not recommended if there are big version differences). +CUDA and MSVC have strong version dependencies, so even if you use VS 2017 / 2019, you will get build errors like `nvcc fatal : Host compiler targets unsupported OS`. For this kind of problem, please install the corresponding VS toolchain in the table below, and then you can either specify the toolset during activation (recommended) or set `CUDAHOSTCXX` to override the Cuda host compiler (not recommended if there are big version differences). | CUDA version | Newest supported VS version | | ------------ | ------------------------------------------------------- | @@ -246,7 +246,7 @@ set CMAKE_GENERATOR_TOOLSET_VERSION=14.11 set DISTUTILS_USE_SDK=1 for /f "usebackq tokens=*" %i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -version [15^,16^) -products * -latest -property installationPath`) do call "%i\VC\Auxiliary\Build\vcvarsall.bat" x64 -vcvars_ver=%CMAKE_GENERATOR_TOOLSET_VERSION% -:: [Optional] If you want to override the cuda host compiler +:: [Optional] If you want to override the Cuda host compiler set CUDAHOSTCXX=C:\Program Files (x86)\Microsoft Visual Studio\2017\Enterprise\VC\Tools\MSVC\14.11.25503\bin\HostX64\x64\cl.exe python setup.py install @@ -291,7 +291,7 @@ should increase shared memory size either with `--ipc=host` or `--shm-size` comm **NOTE:** Must be built with a docker version > 18.06 -The `Dockerfile` is supplied to build images with cuda support and cudnn v7. +The `Dockerfile` is supplied to build images with Cuda support and cuDNN v7. You can pass `PYTHON_VERSION=x.y` make variable to specify which Python version is to be used by Miniconda, or leave it unset to use the default. ```bash @@ -319,7 +319,7 @@ on [our website](https://pytorch.org/previous-versions). ## Getting Started -Three pointers to get you started: +Three-pointers to get you started: - [Tutorials: get you started with understanding and using PyTorch](https://pytorch.org/tutorials/) - [Examples: easy to understand pytorch code across all domains](https://github.com/pytorch/examples) - [The API Reference](https://pytorch.org/docs/) @@ -341,31 +341,31 @@ Three pointers to get you started: ## Communication * forums: discuss implementations, research, etc. https://discuss.pytorch.org * GitHub issues: bug reports, feature requests, install issues, RFCs, thoughts, etc. -* Slack: The [PyTorch Slack](https://pytorch.slack.com/) hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration etc. If you are a beginner looking for help, the primary medium is [PyTorch Forums](https://discuss.pytorch.org). If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1 -* newsletter: no-noise, one-way email newsletter with important announcements about PyTorch. You can sign-up here: https://eepurl.com/cbG0rv +* Slack: The [PyTorch Slack](https://pytorch.slack.com/) hosts a primary audience of moderate to experienced PyTorch users and developers for general chat, online discussions, collaboration, etc. If you are a beginner looking for help, the primary medium is [PyTorch Forums](https://discuss.pytorch.org). If you need a slack invite, please fill this form: https://goo.gl/forms/PP1AGvNHpSaJP8to1 +* newsletter: no-noise, a one-way email newsletter with important announcements about PyTorch. You can sign-up here: https://eepurl.com/cbG0rv * Facebook page: important announcements about PyTorch. https://www.facebook.com/pytorch * for brand guidelines, please visit our website at [pytorch.org](https://pytorch.org/) ## Releases and Contributing -PyTorch has a 90 day release cycle (major releases). Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues). +PyTorch has a 90-day release cycle (major releases). Please let us know if you encounter a bug by [filing an issue](https://github.com/pytorch/pytorch/issues). We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. -If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. -Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of. +If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. +Sending a PR without discussion might end up resulting in a rejected PR because we might be taking the core in a different direction than you might be aware of. To learn more about making a contribution to Pytorch, please see our [Contribution page](CONTRIBUTING.md). ## The Team -PyTorch is a community driven project with several skillful engineers and researchers contributing to it. +PyTorch is a community-driven project with several skillful engineers and researchers contributing to it. PyTorch is currently maintained by [Adam Paszke](https://apaszke.github.io/), [Sam Gross](https://github.com/colesbury), [Soumith Chintala](http://soumith.ch) and [Gregory Chanan](https://github.com/gchanan) with major contributions coming from hundreds of talented individuals in various forms and means. A non-exhaustive but growing list needs to mention: Trevor Killeen, Sasank Chilamkurthy, Sergey Zagoruyko, Adam Lerer, Francisco Massa, Alykhan Tejani, Luca Antiga, Alban Desmaison, Andreas Koepf, James Bradbury, Zeming Lin, Yuandong Tian, Guillaume Lample, Marat Dukhan, Natalia Gimelshein, Christian Sarofeen, Martin Raison, Edward Yang, Zachary Devito. -Note: this project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor in the Torch community and has helped with many things Torch and PyTorch. +Note: this project is unrelated to [hughperkins/pytorch](https://github.com/hughperkins/pytorch) with the same name. Hugh is a valuable contributor to the Torch community and has helped with many things Torch and PyTorch. ## License -PyTorch is BSD-style licensed, as found in the [LICENSE](LICENSE) file. +PyTorch is a BSD-style licensed, as found in the [LICENSE](LICENSE) file. diff --git a/aten/src/ATen/BatchedFallback.cpp b/aten/src/ATen/BatchedFallback.cpp index 7869db6006bc365..aaafc341334c4e6 100644 --- a/aten/src/ATen/BatchedFallback.cpp +++ b/aten/src/ATen/BatchedFallback.cpp @@ -123,12 +123,14 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta // We assume that torch::jit::Stack is backed by vector for // simplicity. When that is not the case, this code should be updated. const auto& argument = (*stack)[arguments_begin + arg_idx]; - if (arg_idx != *batched_tensor_inputs_pos_iter) { + if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end() + || arg_idx != *batched_tensor_inputs_pos_iter) { // argument isn't a BatchedTensor torch::jit::push(stack, argument); continue; } // argument is a BatchedTensor + TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end()); const auto& physical_view_for_argument = *input_physical_views_iter; torch::jit::push(stack, physical_view_for_argument.tensor().index(index)); batched_tensor_inputs_pos_iter++; diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index e97b3d368d1c600..33ac088b70a7c3c 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -64,7 +64,7 @@ file(GLOB native_cpp "native/*.cpp") file(GLOB native_mkl_cpp "native/mkl/*.cpp") file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp") file(GLOB vulkan_cpp "vulkan/*.cpp") -file(GLOB native_vulkan_cpp "native/vulkan/*.cpp") +file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp") file(GLOB native_sparse_cpp "native/sparse/*.cpp") file(GLOB native_quantized_cpp "native/quantized/*.cpp" diff --git a/aten/src/ATen/MemoryOverlap.h b/aten/src/ATen/MemoryOverlap.h index 7af936b60c1b7fa..67f63a64668c35d 100644 --- a/aten/src/ATen/MemoryOverlap.h +++ b/aten/src/ATen/MemoryOverlap.h @@ -24,7 +24,7 @@ CAFFE2_API void assert_no_internal_overlap(TensorImpl* t); CAFFE2_API MemOverlapStatus get_overlap_status(const Tensor& a, const Tensor& b); CAFFE2_API MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b); -void assert_no_partial_overlap(const Tensor& a, const Tensor& b); +CAFFE2_API void assert_no_partial_overlap(const Tensor& a, const Tensor& b); void assert_no_partial_overlap(TensorImpl* a, TensorImpl* b); } diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 74c4e0b69cf3169..c7180f9d558ecf0 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -202,8 +202,6 @@ Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const c10::op "safe to autocast."); } - -#ifndef USE_STATIC_DISPATCH namespace { /***************************************************************************************************************** This section performs load-time registration for autocast wrappers. @@ -378,7 +376,6 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) { } } -#endif } // namespace autocast } // namespace at diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index cd6fa93515da251..57a36b135eaf429 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -23,8 +23,6 @@ _(aten, _abs) \ _(aten, _addmv) \ _(aten, _addr) \ _(aten, _arange) \ -_(aten, _asinh) \ -_(aten, _atanh) \ _(aten, _argmax) \ _(aten, _argmin) \ _(aten, _baddbmm_mkl) \ @@ -656,8 +654,6 @@ _(aten, stft) \ _(aten, storage_offset) \ _(aten, stride) \ _(aten, strides) \ -_(aten, sub) \ -_(aten, sub_) \ _(aten, rsub) \ _(aten, sum) \ _(aten, sum_to_size) \ diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index a38044b1a304e85..9738bfaabb8b7be 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -95,6 +95,7 @@ namespace c10 { _(prim, Guard) \ _(prim, BailOut) \ _(prim, TypeCheck) \ + _(prim, FallbackGraph) \ _(prim, FusedConcat) \ _(prim, ConstantChunk) \ _(prim, MMTreeReduce) \ @@ -130,6 +131,7 @@ namespace c10 { _(prim, GetAttr) \ _(prim, HasAttr) \ _(prim, profile) \ + _(prim, profile_optional) \ _(prim, AddStatValue) \ _(prim, TimePoint) \ _(prim, CallFunction) \ @@ -156,16 +158,25 @@ namespace c10 { _(aten, asin_) \ _(aten, arcsin) \ _(aten, arcsin_) \ + _(aten, asinh) \ + _(aten, asinh_) \ + _(aten, arcsinh) \ + _(aten, arcsinh_) \ _(aten, atan) \ _(aten, atan_) \ _(aten, arctan) \ _(aten, arctan_) \ + _(aten, atanh) \ + _(aten, atanh_) \ + _(aten, arctanh) \ + _(aten, arctanh_) \ _(aten, clamp) \ _(aten, clamp_) \ _(aten, clip) \ _(aten, clip_) \ _(aten, det) \ _(aten, linalg_det) \ + _(aten, linalg_norm) \ _(aten, append) \ _(aten, item) \ _(aten, format) \ @@ -203,6 +214,10 @@ namespace c10 { _(aten, list) \ _(aten, wait) \ _(aten, save) \ + _(aten, sub) \ + _(aten, sub_) \ + _(aten, subtract) \ + _(aten, subtract_) \ _(aten, keys) \ _(aten, ord) \ _(aten, chr) \ @@ -293,6 +308,8 @@ namespace c10 { _(attr, inplace) \ _(attr, input_as_shape) \ _(attr, is_zero) \ + _(attr, num_none) \ + _(attr, num_present) \ _(attr, perm) \ _(attr, sizes) \ _(attr, starts) \ diff --git a/aten/src/ATen/core/op_registration/op_whitelist.h b/aten/src/ATen/core/op_registration/op_whitelist.h index 92d71e2628c89e3..cf7f09bb0f9a3e7 100644 --- a/aten/src/ATen/core/op_registration/op_whitelist.h +++ b/aten/src/ATen/core/op_registration/op_whitelist.h @@ -60,9 +60,13 @@ constexpr bool op_whitelist_check(string_view op_name) { #else return op_whitelist_contains( C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), - // Strip overload name (as whitelist doesn't contain overloads) - OperatorNameView::parse(op_name).name - ); + // This function is majorly used for mobile selective build with + // root operators, where the overload is included in the whitelist. + op_name); + // // Strip overload name (as whitelist doesn't contain overloads) + // // Another function based on this may be added when there's usage + // // on op names without overload. + // OperatorNameView::parse(op_name).name); #endif } @@ -76,6 +80,12 @@ constexpr bool schema_whitelist_check(string_view schema) { #endif } +// schema_whitelist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST. +// Add this API to pass arbitrary whitelist. +constexpr bool op_whitelist_contains_name_in_schema(string_view whitelist, string_view schema) { + return op_whitelist_contains(whitelist, schema.substr(0, schema.find("("))); +} + // Returns true iff the given dispatch key is on the whitelist // and should be registered. When we turn this on, the list of valid // mobile dispatch keys is hard coded (but you need to make sure diff --git a/aten/src/ATen/cudnn/AutocastRNN.cpp b/aten/src/ATen/cudnn/AutocastRNN.cpp index 4a900f1309b6a50..31e1a26e8fb7eca 100644 --- a/aten/src/ATen/cudnn/AutocastRNN.cpp +++ b/aten/src/ATen/cudnn/AutocastRNN.cpp @@ -104,14 +104,12 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, #endif // AT_CUDNN_ENABLED() } -#ifndef USE_STATIC_DISPATCH namespace { TORCH_LIBRARY_IMPL(aten, Autocast, m) { m.impl("_cudnn_rnn", TORCH_FN((&at::autocast::_cudnn_rnn_cast_reflatten))); } } // anonymous namespace -#endif } // namespace autocast } // namespace at diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index e26bb3941b3a7e6..f996e73e5d9ccdb 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -146,14 +146,10 @@ def TypedDict(name, attrs, total=True): # type: ignore // ${schema_string} ${return_type} Tensor::${api_name}(${method_formals}) const { -#ifdef USE_STATIC_DISPATCH - ${static_dispatch_method_body} -#else static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::${operator_name}", "${overload_name}") .typed<${tensor_method_cpp_signature}>(); return op.call(${tensor_method_actuals}); -#endif } """) @@ -172,45 +168,13 @@ def TypedDict(name, attrs, total=True): # type: ignore // ${schema_string} ${return_type} ${api_name}(${formals}) { -#ifdef USE_STATIC_DISPATCH - ${static_dispatch_function_body} -#else static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("aten::${operator_name}", "${overload_name}") .typed<${function_cpp_signature}>(); return op.call(${function_actuals}); -#endif -} -""") - -# In order to rely on the linker to strip unused ops, it requires us to dispatch statically -# in Functions.h and TensorMethods.cpp. -# -# NB: The default body also needs to apply a variable guard, as in some -# situations what we think is a default body actually does have an -# explicit derivative, and thereby would have gotten unwrapped by -# the time you get to the implementation. -STATIC_DISPATCH_FUNCTION_DEFAULT_BODY = CodeTemplate("""\ -at::AutoNonVariableTypeMode _var_guard(true); -${return_call} TypeDefault::${type_wrapper_name}(${actuals}); -""") - -STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\ -at::AutoNonVariableTypeMode _var_guard(true); -${dispatch_key_init} -switch (dispatchKeyToBackend(${dispatch_key_var_name})) { - ${static_dispatch_function_cases} - default: - AT_ERROR("${api_name} not implemented for ", at::toString(${dispatch_key_var_name})); } """) -STATIC_DISPATCH_FUNCTION_SWITCH_CASE = CodeTemplate("""\ -case Backend::${backend}: - ${return_call} ${backend}Type::${type_wrapper_name}(${actuals}); - break; -""") - IFDEF_BLOCK = CodeTemplate("""\ #ifdef ${ifdef_guard} ${content} @@ -246,10 +210,6 @@ def TypedDict(name, attrs, total=True): # type: ignore ('ComplexDouble', 'ComplexDouble', 'ComplexDouble', False), ] -static_dispatch_backends = ['CPU', 'QuantizedCPU', 'Vulkan'] -static_dispatch_backends_ifdef_guard = {'Vulkan' : 'USE_VULKAN'} - - class NYIError(Exception): """Indicates we don't support this declaration yet""" @@ -1136,44 +1096,6 @@ def swizzle_self(f): # blegh method_actuals = maybe_unwrap_optional_tensors(option, formals, option['method_actuals']) - if isinstance(type_method_dispatch, dict): - static_dispatch_function_cases = [] - # NB: As this code is currently written, there will NEVER be - # a backend generated for variable dispatch. There is nothing - # stopping us from actually implementing this, however, if you - # really wanted variable on mobile, there's nothing stopping - # you from implementing this (however, you would have an - # annoying phase problem, since code generation for variable - # happens in tools/ which happens later than here.) - # - # If you pass in a variable to the dispatch, and variable is - # enabled, this switch will fail. This is intentional: you - # probably need to disable variable globally in the mobile - # calling code. - for backend in static_dispatch_backends: - if backend in type_method_dispatch: - static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute( - option, - backend=backend, - backend_function=type_method_dispatch[backend], - actuals=method_actuals) - if (backend in static_dispatch_backends_ifdef_guard): - static_dispatch_function_cases.append(IFDEF_BLOCK.substitute( - option, - ifdef_guard=static_dispatch_backends_ifdef_guard[backend], - content=static_dispatch_function_case)) - else: - static_dispatch_function_cases.append(static_dispatch_function_case) - - static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute( - option, - dispatch_key_var_name=dispatch_key_var_name, - dispatch_key_init=dispatch_key_init, - static_dispatch_function_cases=static_dispatch_function_cases) - else: - static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute( - option, actuals=method_actuals) - # See NOTE[UnboxedOnly] if option['use_c10_dispatcher'] == 'full': tensor_method_actuals = option['schema_order_method_actuals'] @@ -1184,13 +1106,12 @@ def swizzle_self(f): # blegh tensor_method_cpp_signature = option['cpp_signature'] method_definition = TENSOR_METHOD_DEFINITION.substitute( - option, static_dispatch_method_body=static_dispatch_method_body, + option, tensor_method_actuals=tensor_method_actuals, tensor_method_cpp_signature=tensor_method_cpp_signature ) return FunctionCode( - declaration=TENSOR_METHOD_DECLARATION.substitute( - option, static_dispatch_method_body=static_dispatch_method_body), + declaration=TENSOR_METHOD_DECLARATION.substitute(option), definition=method_definition) def gen_namespace_function(option, multidispatch_formals): @@ -1204,31 +1125,6 @@ def gen_namespace_function(option, multidispatch_formals): actuals = maybe_unwrap_optional_tensors(option, formals, option['actuals']) - if isinstance(type_method_dispatch, dict): - static_dispatch_function_cases = [] - for backend in static_dispatch_backends: - if backend in type_method_dispatch: - static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute( - option, - backend=backend, - backend_function=type_method_dispatch[backend], - actuals=actuals) - if (backend in static_dispatch_backends_ifdef_guard): - static_dispatch_function_cases.append(IFDEF_BLOCK.substitute( - option, - ifdef_guard=static_dispatch_backends_ifdef_guard[backend], - content=static_dispatch_function_case)) - else: - static_dispatch_function_cases.append(static_dispatch_function_case) - static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute( - option, - dispatch_key_var_name=dispatch_key_var_name, - dispatch_key_init=dispatch_key_init, - static_dispatch_function_cases=static_dispatch_function_cases) - else: - static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute( - option, actuals=actuals) - # See NOTE[UnboxedOnly] if option['use_c10_dispatcher'] == 'full': function_actuals = option['schema_order_actuals'] @@ -1239,7 +1135,7 @@ def gen_namespace_function(option, multidispatch_formals): function_cpp_signature = option['cpp_signature'] fn_definition = FUNCTION_DEFINITION.substitute( - option, static_dispatch_function_body=static_dispatch_function_body, + option, function_actuals=function_actuals, function_cpp_signature=function_cpp_signature) diff --git a/aten/src/ATen/native/Activation.cpp b/aten/src/ATen/native/Activation.cpp index 0e301df3c9afdef..a43f16465db84e8 100644 --- a/aten/src/ATen/native/Activation.cpp +++ b/aten/src/ATen/native/Activation.cpp @@ -372,7 +372,16 @@ static Tensor threshold_out( Scalar value, const Tensor& other) { Tensor result = opt_result.value_or(Tensor()); - auto iter = TensorIterator::binary_op(result, self, other); + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay + .add_output(result) + .add_input(self) + .add_input(other) + .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true) + .build(); threshold_stub(iter.device_type(), iter, threshold, value); return iter.output(); } diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index 50a758d8f3b240f..bccba591a529cc3 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -48,9 +48,14 @@ DEFINE_DISPATCH(lcm_stub); DEFINE_DISPATCH(hypot_stub); DEFINE_DISPATCH(nextafter_stub); +static Tensor wrapped_scalar_tensor(Scalar scalar) { + auto tensor = scalar_to_tensor(scalar); + tensor.unsafeGetTensorImpl()->set_wrapped_number(true); + return tensor; +} + Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); alpha_check(iter.dtype(), alpha); add_stub(iter.device_type(), iter, alpha); TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); @@ -71,8 +76,7 @@ Tensor& add_(Tensor& self, const Tensor& other, Scalar alpha) { Tensor& add_relu_impl( Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); Scalar min_val; Scalar max_val; if (self.dtype() == at::kInt) { @@ -124,8 +128,7 @@ Tensor& div_out(Tensor& result, const Tensor& self, const Tensor& other) { "Use true_divide or floor_divide (// in Python) instead."); } - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); div_stub(iter.device_type(), iter); return result; } @@ -150,8 +153,7 @@ Tensor& div_(Tensor& self, const Tensor& other) { } Tensor& remainder_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); remainder_stub(iter.device_type(), iter); return result; } @@ -212,8 +214,7 @@ Tensor& true_divide_(Tensor& self, const Tensor& divisor) { } Tensor& floor_divide_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); div_stub(iter.device_type(), iter); if (result.is_floating_point()) { @@ -242,8 +243,7 @@ Tensor& floor_divide_(Tensor& self, const Tensor& other) { } Tensor& mul_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); mul_stub(iter.device_type(), iter); return result; } @@ -261,8 +261,7 @@ Tensor& mul_(Tensor& self, const Tensor& other) { Tensor& sub_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { sub_check(self, other); - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); alpha_check(iter.dtype(), alpha); sub_stub(iter.device_type(), iter, alpha); TORCH_INTERNAL_ASSERT(result.scalar_type() == iter.output().dtype()); @@ -282,6 +281,35 @@ Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) { return native::sub_out(self, self, other, alpha); } +Tensor sub(const Tensor& self, Scalar other, Scalar alpha) { + return native::sub(self, wrapped_scalar_tensor(other), alpha); +} + +Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) { + return native::sub_(self, wrapped_scalar_tensor(other), alpha); +} + +// subtract, alias for sub +Tensor& subtract_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) { + return at::sub_out(result, self, other, alpha); +} + +Tensor subtract(const Tensor& self, const Tensor& other, Scalar alpha) { + return self.sub(other, alpha); +} + +Tensor& subtract_(Tensor& self, const Tensor& other, Scalar alpha) { + return self.sub_(other, alpha); +} + +Tensor subtract(const Tensor& self, Scalar other, Scalar alpha) { + return self.sub(other, alpha); +} + +Tensor& subtract_(Tensor& self, Scalar other, Scalar alpha) { + return self.sub_(other, alpha); +} + Tensor& sigmoid_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) { auto iter = TensorIterator::binary_op(result, grad_output, output); sigmoid_backward_stub(iter.device_type(), iter); @@ -353,12 +381,6 @@ Tensor& atan2_(Tensor& self, const Tensor& other) { // types (int, float, etc.) to Tensor (only to Scalar). They're not exposed // to Python. -static Tensor wrapped_scalar_tensor(Scalar scalar) { - auto tensor = scalar_to_tensor(scalar); - tensor.unsafeGetTensorImpl()->set_wrapped_number(true); - return tensor; -} - static void check_convert(Scalar scalar, ScalarType scalarType) { // Validate that is possible to convert scalar to tensor dtype without overflow AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{ @@ -425,21 +447,12 @@ Tensor& mul_(Tensor& self, Scalar other) { return native::mul_(self, wrapped_scalar_tensor(other)); } -Tensor sub(const Tensor& self, Scalar other, Scalar alpha) { - return native::sub(self, wrapped_scalar_tensor(other), alpha); -} - -Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) { - return native::sub_(self, wrapped_scalar_tensor(other), alpha); -} - Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) { return native::rsub(self, wrapped_scalar_tensor(other), alpha); } Tensor& bitwise_and_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); bitwise_and_stub(iter.device_type(), iter); return result; } @@ -485,8 +498,7 @@ Tensor& __iand__(Tensor& self, Scalar other) { } Tensor& bitwise_or_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); bitwise_or_stub(iter.device_type(), iter); return result; } @@ -532,8 +544,7 @@ Tensor& __ior__(Tensor& self, Scalar other) { } Tensor& bitwise_xor_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); bitwise_xor_stub(iter.device_type(), iter); return result; } @@ -644,7 +655,7 @@ Tensor& comparison_op_out(Tensor& result, const Tensor& self, const Tensor& othe check_convert(self.item(), other.scalar_type()); } } - auto iter = TensorIterator::comparison_op(result, self, other, /*check_mem_overlap=*/true); + auto iter = TensorIterator::comparison_op(result, self, other); stub(iter.device_type(), iter); return result; } @@ -752,8 +763,7 @@ Tensor& logical_xor_(Tensor& self, Scalar other) { return comparison_op_(self, o Tensor& maximum_out(Tensor& result, const Tensor& self, const Tensor& other) { TORCH_CHECK(!self.is_complex() && !other.is_complex(), "maximum does not support complex inputs."); - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); maximum_stub(iter.device_type(), iter); return result; } @@ -779,8 +789,7 @@ Tensor max(const Tensor& self, const Tensor& other) { Tensor& minimum_out(Tensor& result, const Tensor& self, const Tensor& other) { TORCH_CHECK(!self.is_complex() && !other.is_complex(), "minimum does not support complex inputs."); - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); minimum_stub(iter.device_type(), iter); return result; } @@ -812,16 +821,14 @@ Tensor& floor_divide_(Tensor& self, Scalar other) { } Tensor& fmod_out(Tensor & result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU"); fmod_stub(iter.device_type(), iter); return result; } Tensor& fmod_out(Tensor & result, const Tensor& self, Scalar other) { - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); TORCH_CHECK(iter.device_type() == at::kCPU, "Native fmod only supports CPU"); fmod_scalar_stub(iter.device_type(), iter, other); return result; @@ -846,7 +853,7 @@ Tensor& fmod_(Tensor& self, Scalar other) { } Tensor& logaddexp_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); logaddexp_stub(iter.device_type(), iter); return result; } @@ -857,7 +864,7 @@ Tensor logaddexp(const Tensor& self, const Tensor& other) { } Tensor& logaddexp2_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, self, other); logaddexp2_stub(iter.device_type(), iter); return result; } @@ -868,7 +875,7 @@ Tensor logaddexp2(const Tensor& self, const Tensor& other) { } Tensor& gcd_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/ true); + auto iter = TensorIterator::binary_op(result, self, other); gcd_stub(iter.device_type(), iter); return result; } @@ -883,7 +890,7 @@ Tensor& gcd_(Tensor& self, const Tensor& other) { } Tensor& lcm_out(Tensor& result, const Tensor& self, const Tensor& other) { - auto iter = TensorIterator::binary_op(result, self, other, /*check_mem_overlap=*/ true); + auto iter = TensorIterator::binary_op(result, self, other); lcm_stub(iter.device_type(), iter); return result; } diff --git a/aten/src/ATen/native/Blas.cpp b/aten/src/ATen/native/Blas.cpp index fc3c7d637a3ee89..606b381ca83b5d4 100644 --- a/aten/src/ATen/native/Blas.cpp +++ b/aten/src/ATen/native/Blas.cpp @@ -54,6 +54,8 @@ Tensor &addmv_out(Tensor& result, const Tensor &self, const Tensor &mat, const T "size mismatch, get ", self_.size(0), ", ", mat.size(0), "x", mat.size(1), ",", vec.size(0)); if (mat.numel() == 0) { + // By definition, when beta==0, values in self should be ignored. nans and infs + // should not propagate if (beta.toDouble() == 0.0) { result.zero_(); } else { diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 158c8d9d8aa035a..212dd15eef33587 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -268,7 +268,8 @@ auto ConvParams::use_xnnpack( padding, stride, dilation, - groups); + groups, + transposed); } #endif return false; diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index 736c65b6e8bfc7e..d02dc781ed19a55 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -29,7 +29,12 @@ Tensor& fill_out(Tensor& self, Scalar value) { fill_fast(self, value);}); return self; } - auto iter = TensorIterator::nullary_op(self); + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) // Fill is idempotent, so overlap is okay + .check_all_same_dtype(false) + .add_output(self) + .resize_outputs(false) + .build(); fill_stub(iter.device_type(), iter, value); return self; } diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index dcf74e89a92f01e..7caa6218b63aa6a 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -1,14 +1,24 @@ #include +#include + namespace at { namespace native { -std::vector foreach_add_scalar_kernel_fallback(TensorList tensors, Scalar scalar) { - TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); +std::vector foreach_tensor_add_scalar_kernel_slow(TensorList tensors, Scalar scalar) { + verify_list(tensors); std::vector result; - for (int i = 0; i < tensors.size(); i++) { - auto temp = tensors[i].add(scalar); - result.emplace_back(temp); + for (const auto& t : tensors) { + result.emplace_back(t.add(scalar)); } return result; } + +void foreach_tensor_add_scalar_kernel_slow_(TensorList tensors, Scalar scalar) { + verify_list(tensors); + + for (auto& t : tensors) { + t.add_(scalar); + } +} + }} // namespace at::native diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h new file mode 100644 index 000000000000000..1afab4c6ba2860b --- /dev/null +++ b/aten/src/ATen/native/ForeachUtils.h @@ -0,0 +1,58 @@ +#pragma once +#include + +namespace at { +namespace native { +namespace { + +void verify_list(TensorList tensors) { + TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); + auto expected_dtype = tensors[0].dtype(); + + for (auto t : tensors) { + TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype."); + } +} + +// To go via 'fast' path, several conditions must be satisfied +// - All tensors must have strided layout +// - All tensors must be non-overlapping and dense +// - Resulting tensor must have the same dtype as the input one +bool check_fast_route(TensorList tensors, Scalar scalar) { + TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); + auto expected_device = tensors[0].device(); + + for (auto t : tensors) { + if (t.device() != expected_device) { + return false; + } + + if (t.layout() != at::kStrided) { + return false; + } + + if (!t.is_non_overlapping_and_dense()) { + return false; + } + + // complex scalar + integral or boolean tensor will result in complex tensor + if (scalar.isComplex() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) { + return false; + } + + // float scalar + integral or boolean tensor will result in float tensor + if (scalar.isFloatingPoint() && at::isIntegralType(t.scalar_type(), /*includeBool*/ true)) { + return false; + } + + // integral scalar + boolean tensor will result in integral tensor + if (scalar.isIntegral(/*includeBool*/ false) && t.dtype() == at::kBool) { + return false; + } + } + + return true; +} + +} // namespace +}} // at::native diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 1f71a3cc99c8d13..a692ca0a800e7a0 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -1284,10 +1285,13 @@ Tensor frobenius_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { if (dim.size() == 1 || dim.size() == 0) { return at::norm(self, 2, dim, keepdim); } + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead"); if (self.is_complex()){ - return at::sqrt(at::sum(at::real(self.conj() * self), dim, keepdim)); + return at::sqrt(at::sum(at::real(self.conj() * self), dim_, keepdim)); } else { - return at::sqrt(at::sum((self * self), dim, keepdim)); + return at::sqrt(at::sum((self * self), dim_, keepdim)); } } @@ -1305,10 +1309,13 @@ Tensor &frobenius_norm_out( if (dim.size() == 1 || dim.size() == 0) { return at::norm_out(result, self, 2, dim, keepdim, self.scalar_type()); } + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], "Expected dims to be different, got ", dim, " instead"); if (self.is_complex()){ - return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim, keepdim)); + return at::sqrt_out(result, at::sum(at::real(self.conj() * self), dim_, keepdim)); } else { - return at::sqrt_out(result, at::sum((self * self), dim, keepdim)); + return at::sqrt_out(result, at::sum((self * self), dim_, keepdim)); } } @@ -1342,8 +1349,10 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) { Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2"); + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); - auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim()); + auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); auto permutation_reverse = create_reverse_permutation(permutation); Tensor p = self.permute(permutation); // Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm @@ -1360,19 +1369,243 @@ Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) { Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) { TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2"); + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); - auto permutation = create_dim_backshift_permutation(dim[0], dim[1], self.dim()); + auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); auto permutation_reverse = create_reverse_permutation(permutation); Tensor p = self.permute(permutation); at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim); if (keepdim) { result.unsqueeze_(-1); - result = result.permute(permutation_reverse); + Tensor result_ = result.permute(permutation_reverse); + result.set_(result_); + } + return result; +} + +// Creates a vector of length ndim with values equal to its indices +// (e.g. [0, 1, 2, ..., ndim-1]) +static std::vector make_dim_list(int64_t ndim) { + std::vector dim_list(ndim); + for (int64_t ind = 0; ind < ndim; ind++) { + dim_list[ind] = ind; + } + return dim_list; +} + +// Checks for valid arguments to linalg_norm when type(ord) == str +static void check_str_ord_valid(const std::string& str_ord, optional opt_dim, int64_t ndim, optional opt_dtype) { + TORCH_CHECK((str_ord == "nuc") || (str_ord == "fro"), "Invalid norm order: ", str_ord); + TORCH_CHECK(!opt_dtype.has_value(), "ord=\'", str_ord, "\' does not yet support the dtype argument"); + bool dims_valid = (ndim == 2 && !opt_dim.has_value()) || (opt_dim.has_value() && opt_dim.value().size() == 2); + TORCH_CHECK(dims_valid, "order \"", str_ord, + "\" can only be used if either len(dim) == 2 or (self.dim() == 2 and dim is None)"); +} + +// Performs vector norm for ord = +/-infinity, and the second dimension reduction +// for matrix norms. +static Tensor _norm_min_max(Tensor& self, double ord, int64_t dim, bool keepdim) { + Tensor result; + if (self.numel() == 0 && self.sizes()[dim] > 0) { + // This special case is needed in matrix norm for tensors with 3 or more dims, + // or in vector norm for order inf and -inf for tesnsors with 2 or more dims. + // When the sizes of the dims to be reduced are greater than 0 but another dim + // in the tensor is size 0 (thus numel == 0), we must either flatten or resize + // the second reduction dim to 1, to avoid calling min/max, which would throw + // an error. + if (self.sizes()[dim] != 1) { + auto new_sizes = self.sizes().vec(); + new_sizes[dim] = 1; + self.resize_(new_sizes); + } + result = keepdim ? self : self.flatten(dim); + } else { + if (ord > 0) { + result = std::get<0>(self.max(dim, keepdim)); + } else { + result = std::get<0>(self.min(dim, keepdim)); + } + } + return result; +} + +// Performs matrix norm +static Tensor _linalg_norm_matrix(const Tensor &self, optional opt_ord, + IntArrayRef dim, bool keepdim, optional opt_dtype) { + Tensor result; + auto ord = opt_ord.value_or(2.0).toDouble(); + TORCH_CHECK(self.device().type() == DeviceType::CPU || self.device().type() == DeviceType::CUDA, + "matrix norm only supports CPU AND CUDA device type, got: ", self.device().type()); + TORCH_CHECK(self.layout() == Layout::Strided, + "matrix norm only supports strided layout, got: ", self.layout()); + + TORCH_CHECK(dim.size() == 2, "_linalg_norm_matrix: 'dim' must either specify 2 dimensions. ", + "Got 'dim' specifying ", dim.size(), " dims"); + auto dim_ = dim.vec(); + maybe_wrap_dims(dim_, self.dim()); + TORCH_CHECK(dim_[0] != dim_[1], + "Expected dims to be different, got (", dim[0], ", ", dim[1], ") instead"); + + ScalarType scalarType = opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type(); + TORCH_CHECK( + at::isFloatingType(scalarType) || at::isComplexType(scalarType), + "Can only calculate the mean of floating and complex types. Got ", + toString(scalarType), " instead."); + + Tensor self_; + if (opt_dtype.has_value()) { + self_ = self.to(scalarType); + } else { + self_ = self; + } + + if (std::abs(ord) == 2) { + // Need to shift the reduction dims to the back, because at::svd will only operate on + // the last 2 dimensions + auto permutation = create_dim_backshift_permutation(dim_[0], dim_[1], self.dim()); + auto permutation_reverse = create_reverse_permutation(permutation); + + result = std::get<1>(self_.permute(permutation).svd()).abs(); + result = _norm_min_max(result, ord, result.dim() - 1, keepdim); + + if (keepdim) { + result.unsqueeze_(-1); + result = result.permute(permutation_reverse); + } + } else { + // abs(p) == infinity and abs(p) == 1 will perform identical reductions, except + // that the order of the two dims is swapped. So we can swap the dims if + // abs(p) == infinity to simplify the rest of the operation's logic. + if (std::abs(ord) == INFINITY) { + std::swap(dim_[0], dim_[1]); + } + // If the dim of the second reduction is greater than that of the first reduction + // and we are not keeping the dims, then the fact that the output of the first + // reduction will have one fewer dimension means that the second reduction dim + // will be off by one, so we need to correct that. + if ((dim_[1] > dim_[0]) && !keepdim) { + dim_[1]--; + } + if (std::abs(ord) == 1 || std::abs(ord) == INFINITY) { + result = self_.abs().sum(dim_[0], keepdim); + result = _norm_min_max(result, ord, dim_[1], keepdim); + } else { + TORCH_CHECK(false, "Order ", ord, " not supported for matrix norm"); + } } return result; } +// Performs vector norm +// This function mostly serves as a wrapper for at::norm, but it overrides a few cases +// for numpy compatibility. These cases are corrected within this wrapper, rather than +// in at::norm itself, to avoid breaking backward compatibility. +static Tensor _linalg_norm_vector(const Tensor& self, optional opt_ord, std::vector dim, bool keepdim, optional opt_dtype) { + if (opt_ord.has_value()) { + TORCH_INTERNAL_ASSERT(dim.size() == 1); + auto ord = opt_ord.value().toDouble(); + Tensor self_ = opt_dtype.has_value() ? self.to(opt_dtype.value()) : self; + if (std::abs(ord) == INFINITY) { + // The ord = +/-infinity case is overridden because at::norm does not match numpy + // when the input contains extreme values (like nan or +/-inf) or if the input + // size is degenerate (like size(0), size(0, N), etc) + self_ = self_.abs(); + return _norm_min_max(self_, ord, dim[0], keepdim); + } else if ((self_.numel() == 0) && (ord < 0)) { + // For negative orders with degenerate input sizes, at::norm's result does not + // match numpy. + Tensor result = self_.abs().pow(ord + 1).sum(dim[0], keepdim); + if (ord >= -1) { + // Result must be infinite in this case, and the simplest way to make that + // happen is to simply add infinity + result += INFINITY; + } else { + result = result.pow(1.0 / (ord + 1)); + } + return result; + } + } else { + // If ord == None, need to check for unique dims because at::norm does not check it + // for this case. + std::vector dim_(dim); + maybe_wrap_dims(dim_, self.dim()); + bool unique_dims = (std::unique(dim_.begin(), dim_.end())) == dim_.end(); + TORCH_CHECK(unique_dims, "Expected dims to be different, got this instead: (", dim, ")"); + } + if (opt_dtype.has_value()) { + return at::norm(self, opt_ord, dim, keepdim, opt_dtype.value()); + } else { + return at::norm(self, opt_ord, dim, keepdim); + } +} + +static Tensor& linalg_norm_out_impl(Tensor& result, const Tensor& self, optional opt_num_ord, optional opt_str_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + // Callers must give the ord argument as either a number, a string, or neither. + // Since the user-facing API has no direct control over how this function is called, this is an internal assert. + TORCH_INTERNAL_ASSERT(!(opt_num_ord.has_value() && opt_str_ord.has_value())); + if (opt_dtype.has_value()) { + auto dtype = opt_dtype.value(); + TORCH_CHECK(dtype == result.scalar_type(), "provided dtype must match dtype of result, but got", + "dtype = ", dtype, ", out.dtype = ", result.scalar_type()); + } + int64_t ndim = self.dim(); + Tensor result_; + if (opt_str_ord.has_value()) { + // 'ord' is string + auto str_ord = opt_str_ord.value(); + check_str_ord_valid(str_ord, opt_dim, ndim, opt_dtype); + if (str_ord == "fro") { + result_ = at::frobenius_norm(self, opt_dim.value_or(IntArrayRef({0, 1})), keepdim); + } else if (str_ord == "nuc") { + if (opt_dim.has_value()) { + result_ = at::nuclear_norm(self, opt_dim.value(), keepdim); + } else { + result_ = at::nuclear_norm(self, keepdim); + } + } + } else { + // 'ord' is int or None + std::vector dim_ = opt_dim.has_value() ? opt_dim.value().vec() : make_dim_list(ndim); + if (!opt_num_ord.has_value() || dim_.size() == 1) { + result_ = _linalg_norm_vector(self, opt_num_ord, dim_, keepdim, opt_dtype); + } else if (dim_.size() == 2) { + result_ = _linalg_norm_matrix(self, opt_num_ord.value(), dim_, keepdim, opt_dtype); + } else { + TORCH_CHECK(false, "'dim' must specify 1 or 2 dimensions when order is numerical and input is " + "not 1-D or 2-D"); + } + } + resize_output(result, result_.sizes()); + result.copy_(result_); + return result; +} + +// Numerical or None norms +Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + Tensor result = at::empty({0}, options); + return at::native::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +// Frobenius and nuclear norms +Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + auto options = TensorOptions().dtype(opt_dtype.has_value() ? opt_dtype.value() : self.scalar_type()).device(self.device()); + Tensor result = at::empty({0}, options); + return at::native::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +} + +// Numerical or None norms +Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return linalg_norm_out_impl(result, self, opt_ord, c10::nullopt, opt_dim, keepdim, opt_dtype); +} + +// Frobenius and nuclear norms +Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return linalg_norm_out_impl(result, self, c10::nullopt, ord, opt_dim, keepdim, opt_dtype); +} + static inline Tensor _chain_matmul_general(TensorList matrices, std::vector>& order, int64_t i, int64_t j) { if (i == j) return matrices[i]; diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp index 5587e013949ec5d..414c8a6f6390413 100644 --- a/aten/src/ATen/native/Pow.cpp +++ b/aten/src/ATen/native/Pow.cpp @@ -11,8 +11,7 @@ DEFINE_DISPATCH(pow_tensor_tensor_stub); DEFINE_DISPATCH(pow_tensor_scalar_stub); Tensor& pow_out(Tensor& result, const Tensor& base, const Tensor& exp) { - auto iter = TensorIterator::binary_op(result, base, exp, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(result, base, exp); pow_tensor_tensor_stub(iter.device_type(), iter); return result; } @@ -37,8 +36,7 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) { } else if (!exp.isComplex() && (exp.toDouble() == 1.0)) { result.resize_as_(base).copy_(base); } else { - auto iter = TensorIterator::unary_op(result, base.to(common_dtype), - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, base.to(common_dtype)); pow_tensor_scalar_stub(iter.device_type(), iter, exp); } return result; diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 80d97a7e8a0c58b..5e66819fb98e203 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -1858,80 +1858,58 @@ static auto cell_params_base_registry = return cell_params_deserializers[type](std::move(state)); }); -static auto registry = - torch::RegisterOperators() - .op("aten::quantized_lstm.input(Tensor input, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel( - DispatchKey::CPUTensorId)) - .op("aten::quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel( - DispatchKey::CPUTensorId)) - .op("aten::quantized_lstm.input_legacy(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_lstm_input_legacy), - quantized_lstm_input_legacy>(DispatchKey::CPUTensorId)) - .op("aten::quantized_lstm.data_legacy(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_lstm_data_legacy), - quantized_lstm_data_legacy>(DispatchKey::CPUTensorId)) - .op("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase", - torch::RegisterOperators::options() - .kernel< - decltype(make_quantized_cell_params_dynamic), - make_quantized_cell_params_dynamic>( - DispatchKey::CPUTensorId)) - .op("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase", - torch::RegisterOperators::options() - .catchAllKernel< - decltype(make_quantized_cell_params_fp16), - &make_quantized_cell_params_fp16>()) - .op("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase", - torch::RegisterOperators::options() - .kernel< - decltype(make_quantized_cell_params), - make_quantized_cell_params>(DispatchKey::CPUTensorId)) - .op("aten::quantized_gru.input(Tensor input, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel( - DispatchKey::CPUTensorId)) - .op("aten::quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel( - DispatchKey::CPUTensorId)) - .op("aten::quantized_gru.input_legacy(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_gru_input_legacy), - quantized_gru_input_legacy>(DispatchKey::CPUTensorId)) - .op("aten::quantized_gru.data_legacy(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_gru_data_legacy), - quantized_gru_data_legacy>(DispatchKey::CPUTensorId)) - .op("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_lstm_cell_dynamic), - quantized_lstm_cell_dynamic>(DispatchKey::CPUTensorId)) - .op("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_gru_cell_dynamic), - quantized_gru_cell_dynamic>(DispatchKey::CPUTensorId)) - .op("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_rnn_relu_cell_dynamic), - quantized_rnn_relu_cell_dynamic>(DispatchKey::CPUTensorId)) - .op("quantized::quantized_rnn_tanh_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor", - torch::RegisterOperators::options() - .kernel< - decltype(quantized_rnn_tanh_cell_dynamic), - quantized_rnn_tanh_cell_dynamic>(DispatchKey::CPUTensorId)); +TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(aten, m) { + m.def( + "quantized_lstm.input(Tensor input, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"); + m.def( + "quantized_lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"); + m.def( + "quantized_lstm.input_legacy(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"); + m.def( + "quantized_lstm.data_legacy(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, *, ScalarType? dtype=None, bool use_dynamic=False) -> (Tensor, Tensor, Tensor)"); + m.def( + "quantized_gru.input(Tensor input, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)"); + m.def( + "quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, __torch__.torch.classes.rnn.CellParamsBase[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)"); + m.def( + "quantized_gru.input_legacy(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)"); + m.def( + "quantized_gru.data_legacy(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)"); +} + +TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(quantized, m) { + m.def("make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase"); + m.def("make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase"); + m.def("make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase"); + m.def("quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)"); + m.def("quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"); + m.def("quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"); + m.def("quantized_rnn_tanh_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(aten, CPU, m) { + m.impl("quantized_lstm.input", TORCH_FN(quantized_lstm_input)); + m.impl("quantized_lstm.data", TORCH_FN(quantized_lstm_data)); + m.impl("quantized_lstm.input_legacy", TORCH_FN(quantized_lstm_input_legacy)); + m.impl("quantized_lstm.data_legacy", TORCH_FN(quantized_lstm_data_legacy)); + m.impl("quantized_gru.input", TORCH_FN(quantized_gru_input)); + m.impl("quantized_gru.data", TORCH_FN(quantized_gru_data)); + m.impl("quantized_gru.input_legacy", TORCH_FN(quantized_gru_input_legacy)); + m.impl("quantized_gru.data_legacy", TORCH_FN(quantized_gru_data_legacy)); +} + +TORCH_LIBRARY_IMPL(quantized, CPU, m) { + m.impl("make_quantized_cell_params_dynamic", TORCH_FN(make_quantized_cell_params_dynamic)); + m.impl("make_quantized_cell_params", TORCH_FN(make_quantized_cell_params)); + m.impl("quantized_lstm_cell_dynamic", TORCH_FN(quantized_lstm_cell_dynamic)); + m.impl("quantized_gru_cell_dynamic", TORCH_FN(quantized_gru_cell_dynamic)); + m.impl("quantized_rnn_relu_cell_dynamic", TORCH_FN(quantized_rnn_relu_cell_dynamic)); + m.impl("quantized_rnn_tanh_cell_dynamic", TORCH_FN(quantized_rnn_tanh_cell_dynamic)); +} + +TORCH_LIBRARY_IMPL(quantized, CatchAll, m) { + m.impl("make_quantized_cell_params_fp16", TORCH_FN(make_quantized_cell_params_fp16)); +} } // namespace }} // namespace at::native diff --git a/aten/src/ATen/native/RangeFactories.cpp b/aten/src/ATen/native/RangeFactories.cpp index 8deff003be12282..c2fdde8c58aeea1 100644 --- a/aten/src/ATen/native/RangeFactories.cpp +++ b/aten/src/ATen/native/RangeFactories.cpp @@ -173,7 +173,7 @@ Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) { } Tensor r = result.is_contiguous() ? result : result.contiguous(); - auto iter = TensorIterator::nullary_op(r, /*check_mem_overlap=*/true); + auto iter = TensorIterator::nullary_op(r); arange_stub(iter.device_type(), iter, start, size, step); if (!result.is_contiguous()) { result.copy_(r); diff --git a/aten/src/ATen/native/RowwisePrune.cpp b/aten/src/ATen/native/RowwisePrune.cpp new file mode 100644 index 000000000000000..c373f3a64701697 --- /dev/null +++ b/aten/src/ATen/native/RowwisePrune.cpp @@ -0,0 +1,106 @@ +// Copyright 2004-present Facebook. All Rights Reserved. + +#include + + +namespace at { +namespace native { + +namespace { + +template +std::tuple _rowwise_prune_helper( + const Tensor& weights, const Tensor& mask, + ScalarType compressed_indices_dtype) { + int num_non_masked_rows = 0; + auto mask_contig = mask.contiguous(); + auto mask_data = mask_contig.data_ptr(); + for (int i = 0; i < mask.numel(); ++i) { + num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0); + } + int num_cols = weights.size(1); + auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols}, + weights.options()); + auto compressed_indices_mapping = at::empty({mask.numel()}, + compressed_indices_dtype); + AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, + at::ScalarType::BFloat16, + weights.scalar_type(), + "rowwise_prune_helper", [&]() { + auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr(); + auto compressed_indices_mapping_data = + compressed_indices_mapping.data_ptr(); + auto weights_data = weights.data_ptr(); + int last_row_kept = 0; + for (int i = 0; i < mask.numel(); i++) { + if (mask_data[i]) { + memcpy(pruned_2d_tensor_data + last_row_kept * num_cols, + weights_data + i * num_cols, + num_cols * sizeof (scalar_t)); + compressed_indices_mapping_data[i] = last_row_kept; + last_row_kept++; + } else { + compressed_indices_mapping_data[i] = -1; + } + } + }); + return std::tuple(pruned_2d_tensor, + compressed_indices_mapping); +} + +} // namespace + + +// This operator introduces sparsity to the 'weights' matrix with the help +// of the importance indicator 'mask'. +// +// A row is considered important and not pruned if the mask value for that +// particular row is 1(True) and not important otherwise. +// +// This operator doesn't zero out the pruned rows in-place. Instead, it +// returns a tuple that contains a pruned weights tensor as well as a map that +// can be used to look up the original row in the pruned weights tensor. +// We refer this map as 'compressed indices map' going forward. + +// The 'compressed indices map' is an 1D tensor that contains one entry per +// original row in 'weights'. The array index is the index for the original +// non-pruned weight tensor and the value would be the re-mapped index in the +// pruned weights tensor. If the value for a index is -1, it means the +// corresponding row has been pruned from the original weight tensor. + +// Arguments: +// 'weights' - two dimensional matrix that needs to be prune. +// 'mask' - 1D boolean tensor that represents whether a row is important or +// not. A mask value of 1 means the row should be kept and 0 means the row +// should be pruned. +// +// Returns: +// A tuple containing two tensors, +// 1. A pruned weight tensor that contains only the weights that are preserved +// post pruning. +// 2. An 1D tensor that contains the mapping between original weight row and +// the corresponding row in the pruned weights tensor. +std::tuple rowwise_prune(const Tensor& weights, + const Tensor& mask, + ScalarType compressed_indices_dtype) { + TORCH_CHECK(weights.ndimension() == 2, + "'weights' should have 2 dimensions."); + TORCH_CHECK( + mask.numel() == weights.size(0), + "Number of elements in 'mask' should be equivalent to the " + "number of rows in 'weights'." + ) + TORCH_CHECK( + compressed_indices_dtype == ScalarType::Int || + compressed_indices_dtype == ScalarType::Long, + "compressed_indices_dtype should be either int(int32) or long(int64)."); + + if (compressed_indices_dtype == at::ScalarType::Int) { + return _rowwise_prune_helper(weights, mask, + compressed_indices_dtype); + } + return _rowwise_prune_helper(weights, mask, + compressed_indices_dtype); +} + +}} // namesapce at::native diff --git a/aten/src/ATen/native/SharedReduceOps.h b/aten/src/ATen/native/SharedReduceOps.h index a1f324996698585..765c9937277fe82 100644 --- a/aten/src/ATen/native/SharedReduceOps.h +++ b/aten/src/ATen/native/SharedReduceOps.h @@ -344,15 +344,29 @@ namespace detail { template struct LessOrNan { - C10_DEVICE bool operator () (scalar_t a, scalar_t b) const { - return at::_isnan(a) || a < b; + C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a < b); } }; template struct GreaterOrNan { - C10_DEVICE bool operator () (scalar_t a, scalar_t b) const { - return at::_isnan(a) || a > b; + C10_DEVICE bool operator () (scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a > b); } }; @@ -367,11 +381,11 @@ struct MinMaxReductionOps { } static C10_DEVICE arg_t reduce(arg_t arg, scalar_t val, int64_t idx) { - return comp_t{}(arg.first, val) ? arg : arg_t(val, idx); + return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx); } static C10_DEVICE arg_t combine(arg_t a, arg_t b) { - return comp_t{}(a.first, b.first) ? a : b; + return comp_t{}(a.first, b.first, a.second, b.second) ? a : b; } static C10_DEVICE arg_t translate_idx(arg_t a, int64_t base_idx) { diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 09c2e2261d2af36..3d333923c8b9c51 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -54,6 +54,7 @@ #include #include #include +#include #include #include #include @@ -353,6 +354,10 @@ Tensor& index_add_cpu_(Tensor & self, int64_t dim, const Tensor & index, const T TORCH_CHECK(numel == (source.dim() == 0 ? 1 : source.size(dim)), "index_add_(): Number of indices should be equal to self.size(dim)"); + at::assert_no_internal_overlap(self); + at::assert_no_partial_overlap(self, index); + at::assert_no_partial_overlap(self, source); + auto index_contig = index.contiguous(); auto index_data = index_contig.data_ptr(); diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp index 8ad277264681a70..2eca94b36b37d48 100644 --- a/aten/src/ATen/native/TensorIterator.cpp +++ b/aten/src/ATen/native/TensorIterator.cpp @@ -816,9 +816,9 @@ void TensorIterator::select_all_keeping_dim(int start_dim, IntArrayRef indices) } TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, - const Tensor& b, bool check_mem_overlap) { + const Tensor& b) { return TensorIteratorConfig() - .set_check_mem_overlap(check_mem_overlap) + .set_check_mem_overlap(true) .add_output(out) .add_input(a) .add_input(b) @@ -830,9 +830,9 @@ TensorIterator TensorIterator::binary_op(Tensor& out, const Tensor& a, } TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a, - const Tensor& b, bool check_mem_overlap) { + const Tensor& b) { return TensorIteratorConfig() - .set_check_mem_overlap(check_mem_overlap) + .set_check_mem_overlap(true) .add_output(out) .add_input(a) .add_input(b) @@ -841,10 +841,9 @@ TensorIterator TensorIterator::comparison_op(Tensor& out, const Tensor& a, .build(); } -TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a, - bool check_mem_overlap) { +TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a) { return TensorIteratorConfig() - .set_check_mem_overlap(check_mem_overlap) + .set_check_mem_overlap(true) .add_output(out) .add_input(a) .cast_common_dtype_to_outputs(false) @@ -853,10 +852,10 @@ TensorIterator TensorIterator::unary_op(Tensor& out, const Tensor& a, .build(); } -TensorIterator TensorIterator::nullary_op(Tensor& out, bool check_mem_overlap) { +TensorIterator TensorIterator::nullary_op(Tensor& out) { return TensorIteratorConfig() + .set_check_mem_overlap(true) .check_all_same_dtype(false) - .set_check_mem_overlap(check_mem_overlap) .add_output(out) // FIXME: workaround for bug: https://github.com/pytorch/pytorch/issues/20342 .resize_outputs(false) diff --git a/aten/src/ATen/native/TensorIterator.h b/aten/src/ATen/native/TensorIterator.h index 9f58347c6a0bdcf..27ef0b8eda837b3 100644 --- a/aten/src/ATen/native/TensorIterator.h +++ b/aten/src/ATen/native/TensorIterator.h @@ -156,14 +156,10 @@ struct CAFFE2_API TensorIterator { void foreach_reduced_elt(loop_subiter_t loop, bool parallelize=true); - static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b, - bool check_mem_overlap = false); - static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b, - bool check_mem_overlap = false); - static TensorIterator unary_op(Tensor& out, const Tensor& a, - bool check_mem_overlap = false); - static TensorIterator nullary_op(Tensor& out, - bool check_mem_overlap = false); + static TensorIterator binary_op(Tensor& out, const Tensor& a, const Tensor& b); + static TensorIterator comparison_op(Tensor& out, const Tensor& a, const Tensor& b); + static TensorIterator unary_op(Tensor& out, const Tensor& a); + static TensorIterator nullary_op(Tensor& out); static TensorIterator reduce_op(Tensor& out, const Tensor& a); static TensorIterator reduce_op(Tensor& out1, Tensor& out2, const Tensor& a); diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 337ea73393ec68f..48dab43b2dc8132 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -53,20 +53,14 @@ bool cudnn_is_acceptable(const Tensor& self) { } Tensor detach(const Tensor& self) { -#ifndef USE_STATIC_DISPATCH // this just exists to give us a hook in VariableType and an entry in Declarations.yaml //AT_ERROR("detach is not implemented for Tensor"); -#endif - // this is no-op for USE_STATIC_DISPATCH mode return self; } Tensor & detach_(Tensor & self) { -#ifndef USE_STATIC_DISPATCH // this just exists to give us a hook in VariableType and an entry in Declarations.yaml //AT_ERROR("detach_ is not implemented for Tensor"); -#endif - // this is no-op for USE_STATIC_DISPATCH mode return self; } diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 50203aef98a5d20..73ad9840b0bead1 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -39,8 +39,7 @@ namespace native { // operators (more is foreseeable) and is more flexible and elegant than the latter. template static inline Tensor& unary_op_impl_out(Tensor& result, const Tensor& self, Stub& stub) { - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); stub(iter.device_type(), iter); return result; } @@ -61,8 +60,7 @@ static inline Tensor& unary_op_impl_with_complex_to_float_out(Tensor& result, co // Runs the function complex->complex, as TensorIterator expects Tensor complex_result = at::empty({0}, self.options()); - auto iter = TensorIterator::unary_op(complex_result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(complex_result, self); stub(iter.device_type(), iter); // Copies the complex result to the actual result and returns it @@ -108,9 +106,9 @@ Tensor acos(const Tensor& self) { return unary_op_impl(self, at::acos_out); } Tensor& acos_(Tensor& self) { return unary_op_impl_(self, at::acos_out); } // arccos, alias for acos -Tensor& arccos_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, acos_stub); } -Tensor arccos(const Tensor& self) { return unary_op_impl(self, at::acos_out); } -Tensor& arccos_(Tensor& self) { return unary_op_impl_(self, at::acos_out); } +Tensor& arccos_out(Tensor& result, const Tensor& self) { return at::acos_out(result, self); } +Tensor arccos(const Tensor& self) { return self.acos(); } +Tensor& arccos_(Tensor& self) { return self.acos_(); } static Tensor wrapped_scalar_tensor(Scalar scalar) { auto tensor = scalar_to_tensor(scalar); @@ -140,18 +138,18 @@ Tensor asin(const Tensor& self) { return unary_op_impl(self, at::asin_out); } Tensor& asin_(Tensor& self) { return unary_op_impl_(self, at::asin_out); } // arcsin, alias of asin -Tensor& arcsin_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, asin_stub); } -Tensor arcsin(const Tensor& self) { return unary_op_impl(self, at::asin_out); } -Tensor& arcsin_(Tensor& self) { return unary_op_impl_(self, at::asin_out); } +Tensor& arcsin_out(Tensor& result, const Tensor& self) { return at::asin_out(result, self); } +Tensor arcsin(const Tensor& self) { return self.asin(); } +Tensor& arcsin_(Tensor& self) { return self.asin_(); } Tensor& atan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atan_stub); } Tensor atan(const Tensor& self) { return unary_op_impl(self, at::atan_out); } Tensor& atan_(Tensor& self) { return unary_op_impl_(self, at::atan_out); } // arctan, alias of atan -Tensor& arctan_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atan_stub); } -Tensor arctan(const Tensor& self) { return unary_op_impl(self, at::atan_out); } -Tensor& arctan_(Tensor& self) { return unary_op_impl_(self, at::atan_out); } +Tensor& arctan_out(Tensor& result, const Tensor& self) { return at::atan_out(result, self); } +Tensor arctan(const Tensor& self) { return self.atan(); } +Tensor& arctan_(Tensor& self) { return self.atan_(); } // Note [Complex abs and angle] // Complex inputs to abs and angle return float results by default. @@ -314,10 +312,20 @@ Tensor& asinh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out Tensor asinh(const Tensor& self) { return unary_op_impl(self, at::asinh_out); } Tensor& asinh_(Tensor& self) { return unary_op_impl_(self, at::asinh_out); } +// arcsinh, alias for asinh +Tensor& arcsinh_out(Tensor& result, const Tensor& self) { return at::asinh_out(result, self); } +Tensor arcsinh(const Tensor& self) { return self.asinh(); } +Tensor& arcsinh_(Tensor& self) { return self.asinh_(); } + Tensor& atanh_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, atanh_stub); } Tensor atanh(const Tensor& self) { return unary_op_impl(self, at::atanh_out); } Tensor& atanh_(Tensor& self) { return unary_op_impl_(self, at::atanh_out); } +// arctanh, alias for atanh +Tensor& arctanh_out(Tensor& result, const Tensor& self) { return at::atanh_out(result, self); } +Tensor arctanh(const Tensor& self) { return self.atanh(); } +Tensor& arctanh_(Tensor& self) { return self.atanh_(); } + Tensor& sqrt_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, sqrt_stub); } Tensor sqrt(const Tensor& self) { return unary_op_impl(self, at::sqrt_out); } Tensor& sqrt_(Tensor& self) { return unary_op_impl_(self, at::sqrt_out); } @@ -333,10 +341,7 @@ Tensor& logit_out( Tensor& result, const Tensor& self, c10::optional eps) { - auto iter = TensorIterator::unary_op( - result, - self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); logit_stub(iter.device_type(), iter, Scalar(eps ? eps.value() : -1.0)); return result; } @@ -435,8 +440,7 @@ Tensor& clamp_out(Tensor& result, const Tensor& self, optional min, opti if (min && max) { TORCH_CHECK(self.layout() == Layout::Strided, "clamp only supports strided layout, got: ", self.layout()); - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); clamp_stub(iter.device_type(), iter, *min, *max); } else if (max) { at::clamp_max_out(result, self, *max); @@ -461,8 +465,7 @@ Tensor& clamp_max_out(Tensor& result, const Tensor& self, Scalar max) { TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); TORCH_CHECK(self.layout() == Layout::Strided, "clamp_max only supports strided layout, got: ", self.layout()); - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); clamp_max_stub(iter.device_type(), iter, max); return result; } @@ -480,8 +483,7 @@ Tensor& clamp_min_out(Tensor& result, const Tensor& self, Scalar min) { TORCH_CHECK(!self.is_complex(), "clamp is not yet implemented for complex tensors."); TORCH_CHECK(self.layout() == Layout::Strided, "clamp_min only supports strided layout, got: ", self.layout()); - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); clamp_min_stub(iter.device_type(), iter, min); return result; } @@ -518,8 +520,7 @@ Tensor& polygamma_(Tensor& self, int64_t n) { } Tensor& polygamma_out(Tensor& result, int64_t n, const Tensor& self) { TORCH_CHECK(n >= 0, "polygamma(n, x) does not support negative n."); - auto iter = TensorIterator::unary_op(result, self, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::unary_op(result, self); polygamma_stub(iter.device_type(), iter, n); return result; } @@ -563,8 +564,7 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) { Tensor& _##op##_out_##prefix(Tensor& result, const Tensor& self) { \ checkDeviceType(#op, result, DeviceType::device); \ checkLayout(#op, result, Layout::Strided); \ - auto iter = TensorIterator::unary_op(result, self, \ - /*check_mem_overlap=*/true); \ + auto iter = TensorIterator::unary_op(result, self); \ op##_stub(iter.device_type(), iter); \ return result; \ } diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 946c0a7276150a9..133d51bd99a89ec 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -318,14 +318,14 @@ void rshift_kernel(TensorIterator& iter) { void lt_kernel(TensorIterator& iter) { if (iter.dtype() == ScalarType::Bool) { - AT_DISPATCH_ALL_TYPES_AND2(kBool, kBFloat16, iter.input_dtype(), "lt_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, iter.input_dtype(), "lt_cpu", [&]() { cpu_kernel(iter, [](scalar_t a, scalar_t b) -> bool { return a < b; }); }); } else { - AT_DISPATCH_ALL_TYPES_AND(kBFloat16, iter.dtype(), "lt_cpu", [&]() { + AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, iter.dtype(), "lt_cpu", [&]() { cpu_kernel_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { diff --git a/aten/src/ATen/native/cpu/FillKernel.cpp b/aten/src/ATen/native/cpu/FillKernel.cpp index be9f1a445d61dce..69a304257fca777 100644 --- a/aten/src/ATen/native/cpu/FillKernel.cpp +++ b/aten/src/ATen/native/cpu/FillKernel.cpp @@ -12,7 +12,7 @@ namespace { template -static void fill_non_native_type(TensorIterator& iter, Scalar value_scalar) { +void fill_non_native_type(TensorIterator& iter, Scalar value_scalar) { auto value = value_scalar.to().x; using H = typename std::make_signed::type; // Signed type has more acceleration // Reserve the representation of value. static_cast(value) is implementation defined. @@ -23,11 +23,24 @@ static void fill_non_native_type(TensorIterator& iter, Scalar value_scalar) { [val]() { return Vec256(val); }); } +template <> +void fill_non_native_type>(TensorIterator& iter, Scalar value_scalar) { + static_assert(sizeof(c10::complex) == sizeof(int32_t), "Size of ComplexHalf should be 32-bits"); + auto value = c10::complex(value_scalar.to>()); + auto val = *reinterpret_cast(std::addressof(value)); + cpu_kernel_vec( + iter, + [val]() -> int32_t { return val; }, + [val]() { return Vec256(val); }); +} + void fill_kernel(TensorIterator& iter, Scalar value_scalar) { if (iter.dtype() == ScalarType::Half) { fill_non_native_type(iter, value_scalar); } else if (iter.dtype() == ScalarType::BFloat16) { fill_non_native_type(iter, value_scalar); + } else if (iter.dtype() == ScalarType::ComplexHalf) { + fill_non_native_type>(iter, value_scalar); } else { AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(at::ScalarType::Bool, iter.dtype(), "fill_cpu", [&]() { scalar_t value = value_scalar.to(); diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index 840b6ce5d75d707..9ad80ec5daeb614 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -98,7 +98,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef } void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(), "index_cpu", [&] { cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) { *(scalar_t*)dst = *(scalar_t*)(src + offset); @@ -108,11 +108,11 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef inde void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) { // NOTE: duplicate indices are only supported if accumulate is true. - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(), "index_put", [&] { if (accumulate) { bool use_parallel_for = ((iter.numel() >= internal::GRAIN_SIZE) && (at::get_num_threads() > 1)); - if (iter.dtype() == at::ScalarType::Float && use_parallel_for) { + if (iter.dtype() == ScalarType::Float && use_parallel_for) { cpu_index_kernel(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) { cpu_atomic_add_float((float*)(dst + offset), *(float*)src); }); @@ -151,11 +151,11 @@ void cpu_masked_fill_kernel(TensorIterator& iter, scalar_t value) { } void masked_fill_kernel(TensorIterator& iter, Scalar value) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "masked_fill", [&] { scalar_t scalar_val = value.to(); auto mask_dtype = iter.input_dtype(0); - if (mask_dtype == at::ScalarType::Bool) { + if (mask_dtype == ScalarType::Bool) { cpu_masked_fill_kernel(iter, scalar_val); } else { cpu_masked_fill_kernel(iter, scalar_val); @@ -187,10 +187,10 @@ void cpu_masked_select_serial_kernel(TensorIterator& iter, const func_t& f) { } void masked_select_serial_kernel(TensorIterator& iter, int64_t result_stride) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "masked_select", [&] { auto mask_dtype = iter.input_dtype(1); - if (mask_dtype == at::ScalarType::Bool) { + if (mask_dtype == ScalarType::Bool) { cpu_masked_select_serial_kernel(iter, [result_stride](char* dst, char* src, int64_t offset) { *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src; }); @@ -226,10 +226,10 @@ void cpu_masked_select_kernel(TensorIterator& iter, const func_t& f) { } void masked_select_kernel(TensorIterator& iter, int64_t result_stride) { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::ScalarType::Bool, at::ScalarType::BFloat16, + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, iter.dtype(), "masked_select", [&] { auto mask_dtype = iter.input_dtype(1); - if (mask_dtype == at::ScalarType::Bool) { + if (mask_dtype == ScalarType::Bool) { cpu_masked_select_kernel(iter, [result_stride](char* dst, char* src, int64_t offset) { *(scalar_t*)(dst + offset*result_stride) = *(scalar_t*)src; }); diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp index dceefdd6efc061a..a22e349a9cf641b 100644 --- a/aten/src/ATen/native/cpu/LerpKernel.cpp +++ b/aten/src/ATen/native/cpu/LerpKernel.cpp @@ -15,8 +15,7 @@ static void lerp_kernel_scalar( const Tensor& end, Scalar weight) { TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype()); - auto iter = TensorIterator::binary_op(ret, self, end, - /*check_mem_overlap=*/true); + auto iter = TensorIterator::binary_op(ret, self, end); AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(ret.scalar_type(), "lerp_kernel_scalar", [&] { using value_t = typename c10::scalar_value_type::type; scalar_t weight_val = weight.to(); diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 40938b97e914f68..cce011a5a3127b2 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -292,7 +292,7 @@ static void argmax_kernel_impl(TensorIterator &iter) { binary_kernel_reduce( iter, ArgMaxOps{}, - std::pair(lower_bound(), -1)); + std::pair(lower_bound(), 0)); }); } @@ -301,7 +301,7 @@ static void argmin_kernel_impl(TensorIterator &iter) { binary_kernel_reduce( iter, ArgMinOps{}, - std::pair(upper_bound(), -1)); + std::pair(upper_bound(), 0)); }); } diff --git a/aten/src/ATen/native/cuda/Activation.cu b/aten/src/ATen/native/cuda/Activation.cu index 359fd5b7bf3e214..b8af7155cf4838f 100644 --- a/aten/src/ATen/native/cuda/Activation.cu +++ b/aten/src/ATen/native/cuda/Activation.cu @@ -29,10 +29,7 @@ void prelu_cuda_kernel_share_weights( Tensor& result, const scalar_t* weight_data) { - at::TensorIterator iter = TensorIteratorConfig() - .add_output(result) - .add_input(input) - .build(); + auto iter = TensorIterator::unary_op(result, input); at::native::gpu_kernel(iter, [weight_data] GPU_LAMBDA (scalar_t input_val) { @@ -540,7 +537,16 @@ static Tensor threshold_out_cuda( Scalar value, const Tensor& other) { Tensor result = opt_result.value_or(Tensor()); - auto iter = TensorIterator::binary_op(result, self, other); + auto iter = TensorIteratorConfig() + .set_check_mem_overlap(false) // threshold is idempotent, so overlap is okay + .add_output(result) + .add_input(self) + .add_input(other) + .allow_cpu_scalars(true) + .promote_inputs_to_common_dtype(true) + .cast_common_dtype_to_outputs(true) + .enforce_safe_casting_to_output(true) + .build(); threshold_kernel(iter, threshold, value); return iter.output(); } diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh new file mode 100644 index 000000000000000..05c7c47784038d6 --- /dev/null +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -0,0 +1,126 @@ +#include +#include + +namespace at { namespace native { + +namespace { + +template +struct AddScalarFunctor_ { + __device__ void operator() ( + int chunk_size, + TensorListMetadata<1>& tl, + T scalar) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* x = (T*)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + T r_x[kILP]; + + // to make things simple, we put aligned case in a different code path + if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x)) { + for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { + // load + load_store(r_x, x, 0 , i_start); +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_x[ii] = static_cast(r_x[ii]) + scalar; + } + // store + load_store(x, r_x, i_start, 0); + } + } + else { + // Non-divergent exit condition for __syncthreads, not necessary here + for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_x[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if(i < n && i < chunk_size) { + r_x[ii] = x[i]; + } + } +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_x[ii] = static_cast(r_x[ii]) + scalar; + } +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if(i < n && i < chunk_size) + x[i] = r_x[ii]; + } + } + } + } +}; + +template +struct AddScalarFunctor { + __device__ void operator() ( + int chunk_size, + TensorListMetadata<2>& tl, + T scalar) { + int tensor_loc = tl.block_to_tensor[blockIdx.x]; + int chunk_idx = tl.block_to_chunk[blockIdx.x]; + int n = tl.sizes[tensor_loc]; + + T* x = (T*)tl.addresses[0][tensor_loc]; + x += chunk_idx * chunk_size; + + T* out = (T*)tl.addresses[1][tensor_loc]; + out += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + T r_x[kILP]; + T r_out[kILP]; + + // to make things simple, we put aligned case in a different code path + if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) { + for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { + // load + load_store(r_x, x, 0 , i_start); +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_out[ii] = static_cast(r_x[ii]) + scalar; + } + // store + load_store(out, r_out, i_start, 0); + } + } + else { + // Non-divergent exit condition for __syncthreads, not necessary here + for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_x[ii] = 0; + int i = i_start + threadIdx.x + ii * blockDim.x; + if(i < n && i < chunk_size) { + r_x[ii] = x[i]; + } + } +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + r_out[ii] = static_cast(r_x[ii]) + scalar; + } +#pragma unroll + for(int ii = 0; ii < kILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if(i < n && i < chunk_size) + out[i] = r_out[ii]; + } + } + } + } +}; + +} // namespace + +}} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu b/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu index 6924aef1ccbe22c..772012c608fba7a 100644 --- a/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu +++ b/aten/src/ATen/native/cuda/ForeachTensorAddScalar.cu @@ -1,81 +1,14 @@ #include -#include -#include - -// NOTE: CUDA on Windows requires that the enclosing function -// of a __device__ lambda not have internal linkage. +#include +#include namespace at { namespace native { -namespace { - -template -struct AddScalarFunctor { - __device__ void operator() ( - int chunk_size, - TensorListMetadata<2>& tl, - x_t scalar) { - int tensor_loc = tl.block_to_tensor[blockIdx.x]; - int chunk_idx = tl.block_to_chunk[blockIdx.x]; - int n = tl.sizes[tensor_loc]; - - x_t* x = (x_t*)tl.addresses[0][tensor_loc]; - x += chunk_idx * chunk_size; - - out_t* out = (out_t*)tl.addresses[1][tensor_loc]; - out += chunk_idx * chunk_size; - - n -= chunk_idx * chunk_size; - - x_t r_x[kILP]; - out_t r_out[kILP]; - - // to make things simple, we put aligned case in a different code path - if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) { - for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) { - // load - load_store(r_x, x, 0 , i_start); -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_out[ii] = static_cast(r_x[ii]) + scalar; - } - // store - load_store(out, r_out, i_start, 0); - } - } - else { - // Non-divergent exit condition for __syncthreads, not necessary here - for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) { -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_x[ii] = 0; - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) { - r_x[ii] = x[i]; - } - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - r_out[ii] = static_cast(r_x[ii]) + scalar; - } -#pragma unroll - for(int ii = 0; ii < kILP; ii++) { - int i = i_start + threadIdx.x + ii * blockDim.x; - if(i < n && i < chunk_size) - out[i] = r_out[ii]; - } - } - } - } -}; - -} // namespace - std::vector foreach_tensor_add_scalar_kernel_cuda(TensorList tensors, Scalar scalar) { - TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); + verify_list(tensors); if (!check_fast_route(tensors, scalar)) { - return at::native::foreach_add_scalar_kernel_fallback(tensors, scalar); + return at::native::foreach_tensor_add_scalar_kernel_slow(tensors, scalar); } std::vector> tensor_lists; @@ -88,9 +21,24 @@ std::vector foreach_tensor_add_scalar_kernel_cuda(TensorList tensors, Sc tensor_lists.emplace_back(std::move(vec_res)); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_tensor_add_scalar_kernel_cuda", [&]() { - multi_tensor_apply<2>(tensor_lists, AddScalarFunctor(), scalar.to()); + multi_tensor_apply<2>(tensor_lists, AddScalarFunctor(), scalar.to()); }); return tensor_lists[1]; } +void foreach_tensor_add_scalar_kernel_cuda_(TensorList tensors, Scalar scalar) { + verify_list(tensors); + + if (!check_fast_route(tensors, scalar)) { + return at::native::foreach_tensor_add_scalar_kernel_slow_(tensors, scalar); + } + + std::vector> tensor_lists; + tensor_lists.emplace_back(std::move(tensors.vec())); + + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kBool, kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_tensor_add_scalar_kernel_cuda_", [&]() { + multi_tensor_apply<1>(tensor_lists, AddScalarFunctor_(), scalar.to()); + }); +} + }} // namespace at::native diff --git a/aten/src/ATen/native/cuda/ForeachUtils.cuh b/aten/src/ATen/native/cuda/ForeachUtils.cuh deleted file mode 100644 index e3800fb7f6fee06..000000000000000 --- a/aten/src/ATen/native/cuda/ForeachUtils.cuh +++ /dev/null @@ -1,56 +0,0 @@ -#pragma once -#include -#include -#include -namespace at { -namespace native { -namespace { - -static constexpr int64_t kILP = 4; -static constexpr int64_t kChunkSize = 65536; -static constexpr int64_t kBlockSize = 512; - -template -__device__ __forceinline__ bool is_aligned(T* p){ - return ((uint64_t)p) % (kILP * sizeof(T)) == 0; -} - -template -__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ - using LT = at::native::memory::aligned_vector; - ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; -} - -} - -bool check_fast_route(TensorList tensors, Scalar scalar) { - TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor."); - auto expected_dtype = tensors[0].dtype(); - auto expected_device = tensors[0].device(); - - for (auto t : tensors) { - if (t.dtype() != expected_dtype) { - return false; - } - - if (t.device() != expected_device) { - return false; - } - - if (t.layout() != at::kStrided) { - return false; - } - - if (!t.is_non_overlapping_and_dense()) { - return false; - } - - if ((at::isIntegralType(t.scalar_type(), true) && scalar.isFloatingPoint()) || - t.scalar_type() == at::kBool) { - return false; - } - } - - return true; -} -}} // at::native diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 688f035bc4a9edd..b444ca418906a11 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -449,6 +450,10 @@ Tensor& index_add_cuda_(Tensor & self, int64_t dim, const Tensor & index, const TORCH_CHECK(source.dim() <= MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); TORCH_CHECK(index.dim() <= MAX_CUTORCH_DIMS, CUTORCH_DIM_WARNING); + at::assert_no_internal_overlap(self); + at::assert_no_partial_overlap(self, index); + at::assert_no_partial_overlap(self, source); + // The `source` is partitioned into two parts: // -the size of each slice we are indexing, which is the // total size of the tensor ignoring dimension `dim`; diff --git a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp index cd3ce5cb1df80de..ddc03d5fb10669c 100644 --- a/aten/src/ATen/native/cuda/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/cuda/LegacyDefinitions.cpp @@ -3,6 +3,7 @@ #include #include #include +#include namespace at { namespace native { @@ -60,6 +61,7 @@ Tensor & masked_scatter__cuda(Tensor& self, const Tensor & mask, const Tensor & } Tensor & fmod_cuda_out(Tensor & result, const Tensor & self, Scalar other) { + at::assert_no_internal_overlap(result); return legacy::cuda::_th_fmod_out(result, self, other); } @@ -68,6 +70,7 @@ Tensor fmod_cuda(const Tensor & self, Scalar other) { } Tensor & fmod_cuda_out(Tensor & result, const Tensor & self, const Tensor & other) { + at::assert_no_internal_overlap(result); Tensor b_self, b_other; // optimization that codegen used to do; avoids broadcast. if (other.dim() == 0) { @@ -88,6 +91,7 @@ Tensor fmod_cuda(const Tensor & self, const Tensor & other) { } Tensor & fmod_cuda_(Tensor & self, Scalar other) { + at::assert_no_internal_overlap(self); return legacy::cuda::_th_fmod_(self, other); } @@ -96,6 +100,7 @@ Tensor & fmod_cuda_(Tensor & self, const Tensor & other) { if (other.dim() == 0) { return fmod_cuda_(self, other.item()); } + at::assert_no_internal_overlap(self); Tensor b_other; std::tie(b_other) = expand_inplace(self, other, "fmod_"); return legacy::cuda::_th_fmod_(self, b_other); diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index 633c7fe334f88e4..10301ed6a7c40de 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -111,6 +111,11 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma at::ScalarType scalar_type = self_.scalar_type(); if (mat1.numel() == 0) { + // By definition, when beta==0, values in self should be ignored. nans and infs + // should not propagate + if (beta.toComplexDouble() == 0.) { + return result.zero_(); + } return at::native::mul_out(result, self, at::native::scalar_tensor(beta, at::device(at::kCPU).dtype(self.scalar_type()))); } diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index 487c480881f45c0..f82a0d9a58c8ff6 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -1,12 +1,28 @@ #include #include -#include #include +#include +#include namespace at { namespace native { namespace { +static constexpr int64_t kILP = 4; +static constexpr int64_t kChunkSize = 65536; +static constexpr int64_t kBlockSize = 512; + +template +__device__ __forceinline__ bool is_aligned(T* p){ + return ((uint64_t)p) % (kILP * sizeof(T)) == 0; +} + +template +__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){ + using LT = at::native::memory::aligned_vector; + ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset]; +} + // TensorListMetadata has to be < 4KB - the limit for kernel launch argument static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; diff --git a/aten/src/ATen/native/mkldnn/BinaryOps.cpp b/aten/src/ATen/native/mkldnn/BinaryOps.cpp index e0bea2d2d1e9dab..3364fe8b335c683 100644 --- a/aten/src/ATen/native/mkldnn/BinaryOps.cpp +++ b/aten/src/ATen/native/mkldnn/BinaryOps.cpp @@ -12,27 +12,27 @@ Tensor& mkldnn_add_out( const Tensor& self, const Tensor& other, Scalar alpha) { - AT_ERROR("mkldnn_add_out: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_add_out: ATen not compiled with MKLDNN support"); } Tensor mkldnn_add(const Tensor& self, const Tensor& other, Scalar alpha) { - AT_ERROR("mkldnn_add: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_add: ATen not compiled with MKLDNN support"); } Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) { - AT_ERROR("mkldnn_add_: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_add_: ATen not compiled with MKLDNN support"); } Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) { - AT_ERROR("mkldnn_mul_out: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_mul_out: ATen not compiled with MKLDNN support"); } Tensor mkldnn_mul(const Tensor& self, const Tensor& other) { - AT_ERROR("mkldnn_mul: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_mul: ATen not compiled with MKLDNN support"); } Tensor& mkldnn_mul_(Tensor& self, const Tensor& other) { - AT_ERROR("mkldnn_mul_: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_mul_: ATen not compiled with MKLDNN support"); } } // namespace native @@ -76,7 +76,7 @@ Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) { } Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) { - AT_ASSERTM(result.sizes() == self.sizes(), + TORCH_CHECK(result.sizes() == self.sizes(), "mkldnn_mul_out: the output size should be same as input size"); ideep::tensor& z = itensor_from_mkldnn(result); ideep::tensor& x = itensor_from_mkldnn(self); @@ -89,7 +89,7 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other) return result; } else { - AT_ASSERTM(self.sizes() == other.sizes(), + TORCH_CHECK(self.sizes() == other.sizes(), "mkldnn_mul_out: currently mkldnn not support broadcasting"); ideep::tensor y = itensor_from_mkldnn(other); ideep::binary::compute(x, y, z, dnnl::algorithm::binary_mul); diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index 0479aba4c54e877..664f7bbd8f1e404 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -6,28 +6,28 @@ namespace at { namespace native { -at::Tensor mkldnn_convolution( - const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, +Tensor mkldnn_convolution( + const Tensor& input, const Tensor& weight, const Tensor& bias, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups) { - AT_ERROR("mkldnn_convolution_forward: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_convolution_forward: ATen not compiled with MKLDNN support"); } -at::Tensor mkldnn_convolution_backward_input( - IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, +Tensor mkldnn_convolution_backward_input( + IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - AT_ERROR("mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_convolution_backward_input: ATen not compiled with MKLDNN support"); } -std::tuple mkldnn_convolution_backward_weights( - IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, +std::tuple mkldnn_convolution_backward_weights( + IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { - AT_ERROR("mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_convolution_backward_weights: ATen not compiled with MKLDNN support"); } -std::tuple mkldnn_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, +std::tuple mkldnn_convolution_backward( + const Tensor& input, const Tensor& grad_output_t, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array output_mask) { - AT_ERROR("mkldnn_convolution_backward: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_convolution_backward: ATen not compiled with MKLDNN support"); } }} @@ -59,9 +59,9 @@ ideep::tensor _mkldnn_convolution( const ideep::tensor& x, const ideep::tensor& w, const c10::optional& b, - at::IntArrayRef padding, - at::IntArrayRef stride, - at::IntArrayRef dilation, + IntArrayRef padding, + IntArrayRef stride, + IntArrayRef dilation, int64_t groups) { auto kernel_size = w.get_dims(); @@ -98,10 +98,10 @@ ideep::tensor _mkldnn_convolution( return y; } -at::Tensor mkldnn_convolution( - const at::Tensor& input, - const at::Tensor& weight, - const at::Tensor& bias, +Tensor mkldnn_convolution( + const Tensor& input, + const Tensor& weight, + const Tensor& bias, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, @@ -131,7 +131,7 @@ at::Tensor mkldnn_convolution( } Tensor mkldnn_convolution_backward_input( - IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, + IntArrayRef input_size, const Tensor& grad_output, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { auto mkldnn_grad_output = get_mkldnn_tensor(grad_output); @@ -153,8 +153,8 @@ Tensor mkldnn_convolution_backward_input( grad_output.options())); } -std::tuple mkldnn_convolution_backward_weights( - IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, +std::tuple mkldnn_convolution_backward_weights( + IntArrayRef weight_size, const Tensor& grad_output, const Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) { const ideep::tensor mkldnn_grad_output = get_mkldnn_tensor(grad_output); @@ -193,8 +193,8 @@ std::tuple mkldnn_convolution_backward_weights( grad_output.options()))); } -std::tuple mkldnn_convolution_backward( - const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight, +std::tuple mkldnn_convolution_backward( + const Tensor& input, const Tensor& grad_output_t, const Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, std::array output_mask) { Tensor grad_output = grad_output_t.contiguous(); diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index 982d55a7081c868..20479ba9164f83a 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -11,7 +11,7 @@ Tensor mkldnn_linear( const Tensor& self, const Tensor& weight, const Tensor& bias) { - AT_ERROR("mkldnn_linear: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_linear: ATen not compiled with MKLDNN support"); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp b/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp index 9d15f74cdeb0c18..cf005e28f7b6672 100644 --- a/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp +++ b/aten/src/ATen/native/mkldnn/MkldnnTensorMath.cpp @@ -11,7 +11,7 @@ namespace at { namespace native { Tensor& mkldnn_zero_(Tensor& self) { - AT_ERROR("mkldnn_zero_: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_zero_: ATen not compiled with MKLDNN support"); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index a2c6ef8704aef0a..86d9d0643a27c4e 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -17,7 +17,7 @@ std::tuple mkldnn_batch_norm( bool train, double momentum, double eps) { - AT_ERROR("mkldnn_batch_norm: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_batch_norm: ATen not compiled with MKLDNN support"); } } // namespace native @@ -49,7 +49,7 @@ std::tuple mkldnn_batch_norm( if (train) { // TODO: support training - AT_ERROR("mkldnn_batch_norm: mkldnn training is not supported in yet."); + TORCH_CHECK(false, "mkldnn_batch_norm: mkldnn training is not supported in yet."); // ideep::tensor saved_mean; // ideep::tensor saved_var; @@ -60,7 +60,7 @@ std::tuple mkldnn_batch_norm( // new_with_itensor_mkldnn(std::move(saved_mean), input.options()), // new_with_itensor_mkldnn(std::move(saved_var), input.options())); } else { - AT_ASSERTM(input.dim() == 4 || input.dim() == 5, + TORCH_CHECK(input.dim() == 4 || input.dim() == 5, "mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm"); ideep::batch_normalization_forward_inference::compute( x, m, v, w, b, y, eps); diff --git a/aten/src/ATen/native/mkldnn/Pooling.cpp b/aten/src/ATen/native/mkldnn/Pooling.cpp index 1757c2882e2a8be..a272bc3d6070b9f 100644 --- a/aten/src/ATen/native/mkldnn/Pooling.cpp +++ b/aten/src/ATen/native/mkldnn/Pooling.cpp @@ -184,6 +184,8 @@ Tensor mkldnn_max_pool2d( IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) { + TORCH_CHECK(std::all_of(dilation.cbegin(), dilation.cend(), [](int64_t i) { return 1 == i; }), + "mkldnn_max_pool2d does not support dilation case"); return _mkldnn_pooling( input, kernel_size, @@ -201,6 +203,8 @@ Tensor mkldnn_max_pool3d( IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) { + TORCH_CHECK(std::all_of(dilation.cbegin(), dilation.cend(), [](int64_t i) { return 1 == i; }), + "mkldnn_max_pool3d does not support dilation case"); return _mkldnn_pooling( input, kernel_size, @@ -220,7 +224,7 @@ Tensor mkldnn_avg_pool2d( bool count_include_pad, c10::optional divisor_override) { TORCH_CHECK(!divisor_override.has_value(), - "mkldnn_avg_pool2d operator does not support divisor"); + "mkldnn_avg_pool2d operator does not support divisor"); return _mkldnn_pooling( input, kernel_size, diff --git a/aten/src/ATen/native/mkldnn/Relu.cpp b/aten/src/ATen/native/mkldnn/Relu.cpp index b8615c0772e1c45..42397255caf0048 100644 --- a/aten/src/ATen/native/mkldnn/Relu.cpp +++ b/aten/src/ATen/native/mkldnn/Relu.cpp @@ -8,11 +8,11 @@ namespace at { namespace native { Tensor mkldnn_relu(const Tensor& input) { - AT_ERROR("mkldnn_relu: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_relu: ATen not compiled with MKLDNN support"); } Tensor& mkldnn_relu_(Tensor& input) { - AT_ERROR("mkldnn_relu_: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_relu_: ATen not compiled with MKLDNN support"); } }} diff --git a/aten/src/ATen/native/mkldnn/SoftMax.cpp b/aten/src/ATen/native/mkldnn/SoftMax.cpp index 80f94d74b965bec..cdeb6cb85971967 100644 --- a/aten/src/ATen/native/mkldnn/SoftMax.cpp +++ b/aten/src/ATen/native/mkldnn/SoftMax.cpp @@ -11,7 +11,7 @@ Tensor mkldnn_softmax( const Tensor& self, const int64_t dim, const bool half_to_float) { - AT_ERROR("mkldnn_softmax: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_softmax: ATen not compiled with MKLDNN support"); } } // namespace native @@ -28,7 +28,7 @@ Tensor mkldnn_softmax( const Tensor& self, const int64_t dim, const bool half_to_float) { - AT_ASSERTM( + TORCH_CHECK( !half_to_float, "softmax with half to float conversion is not supported on Mkldnn"); const int64_t wrapped_dim = maybe_wrap_dim(dim, self.dim()); diff --git a/aten/src/ATen/native/mkldnn/TensorFactories.cpp b/aten/src/ATen/native/mkldnn/TensorFactories.cpp index ec9a7a5429142e2..603819ed3287382 100644 --- a/aten/src/ATen/native/mkldnn/TensorFactories.cpp +++ b/aten/src/ATen/native/mkldnn/TensorFactories.cpp @@ -21,7 +21,7 @@ Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::option #else Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional optional_memory_format) { - AT_ERROR("empty_mkldnn: MKL-DNN build is disabled"); + TORCH_CHECK(false, "empty_mkldnn: MKL-DNN build is disabled"); } #endif // AT_MKLDNN_ENABLED() diff --git a/aten/src/ATen/native/mkldnn/TensorShape.cpp b/aten/src/ATen/native/mkldnn/TensorShape.cpp index af33617a32ed4da..3229a07e94609db 100644 --- a/aten/src/ATen/native/mkldnn/TensorShape.cpp +++ b/aten/src/ATen/native/mkldnn/TensorShape.cpp @@ -9,23 +9,23 @@ namespace at { namespace native { Tensor mkldnn_view(const Tensor& self, IntArrayRef size) { - AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_reshape: ATen not compiled with MKLDNN support"); } Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) { - AT_ERROR("mkldnn_reshape: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_reshape: ATen not compiled with MKLDNN support"); } Tensor mkldnn_clone(const Tensor& self, c10::optional optional_memory_format) { - AT_ERROR("mkldnn_clone: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_clone: ATen not compiled with MKLDNN support"); } Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { - AT_ERROR("mkldnn_transpose: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_transpose: ATen not compiled with MKLDNN support"); } Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) { - AT_ERROR("mkldnn_transpose_: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_transpose_: ATen not compiled with MKLDNN support"); } } // namespace native @@ -39,7 +39,7 @@ namespace at { namespace native { Tensor mkldnn_view(const Tensor& self, IntArrayRef size) { - AT_ERROR( + TORCH_CHECK(false, "Currently Mkldnn tensor does not support view. Change to use reshape instead"); } @@ -65,7 +65,7 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional optiona return new_with_itensor_mkldnn(std::move(dst), self.options()); } -Tensor mkldnn_transpose(const Tensor & self, int64_t dim0, int64_t dim1) { +Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) { const ideep::tensor& x = itensor_from_mkldnn(self); ideep::tensor y; std::vector axes(x.ndims()); @@ -76,7 +76,7 @@ Tensor mkldnn_transpose(const Tensor & self, int64_t dim0, int64_t dim1) { } Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) { - AT_ERROR("mkldnn_transpose_: in-place mkldnn operations are not supported yet"); + TORCH_CHECK(false, "mkldnn_transpose_: in-place mkldnn operations are not supported yet"); } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/UnaryOps.cpp b/aten/src/ATen/native/mkldnn/UnaryOps.cpp index 5045acd60a57aed..4eb02dc483c5ce4 100644 --- a/aten/src/ATen/native/mkldnn/UnaryOps.cpp +++ b/aten/src/ATen/native/mkldnn/UnaryOps.cpp @@ -8,11 +8,11 @@ namespace at { namespace native { Tensor mkldnn_sigmoid(const Tensor& self) { - AT_ERROR("mkldnn_sigmoid: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_sigmoid: ATen not compiled with MKLDNN support"); } Tensor& mkldnn_sigmoid_(Tensor& self) { - AT_ERROR("mkldnn_sigmoid_: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_sigmoid_: ATen not compiled with MKLDNN support"); } } // namespace native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7d55d18fe4f87d5..d4faba561329753 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -495,6 +495,17 @@ - func: asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +# arcsinh, alias for asinh +- func: arcsinh(Tensor self) -> Tensor + use_c10_dispatcher: full + variants: function, method + +- func: arcsinh_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: full + variants: function, method + +- func: arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + - func: atanh(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method @@ -505,6 +516,17 @@ - func: atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) +# arctanh, alias for atanh +- func: arctanh(Tensor self) -> Tensor + use_c10_dispatcher: full + variants: function, method + +- func: arctanh_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: full + variants: function, method + +- func: arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + - func: as_strided(Tensor(a) self, int[] size, int[] stride, int? storage_offset=None) -> Tensor(a) use_c10_dispatcher: full variants: function, method @@ -1236,6 +1258,9 @@ CPU: _embedding_bag_forward_only_cpu CUDA: _embedding_bag_forward_only_cuda +- func: rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) + use_c10_dispatcher: full + - func: embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) use_c10_dispatcher: full @@ -3584,6 +3609,26 @@ use_c10_dispatcher: full variants: method +# subtract, alias for sub +- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + +- func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + use_c10_dispatcher: full + variants: function, method + +- func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: full + variants: method + +# For C++ only, until we have conversion from C++ numbers to Tensor +- func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + use_c10_dispatcher: full + variants: function, method + +- func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + use_c10_dispatcher: full + variants: method + - func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor use_c10_dispatcher: full variants: function @@ -5727,9 +5772,16 @@ device_guard: False variants: function dispatch: - CPU: foreach_add_scalar_kernel_fallback + CPU: foreach_tensor_add_scalar_kernel_slow CUDA: foreach_tensor_add_scalar_kernel_cuda +- func: _foreach_add_.Scalar(Tensor[](a!) self, Scalar scalar) -> () + device_guard: False + variants: function + dispatch: + CPU: foreach_tensor_add_scalar_kernel_slow_ + CUDA: foreach_tensor_add_scalar_kernel_cuda_ + - func: _mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor, Tensor) use_c10_dispatcher: full dispatch: @@ -7304,6 +7356,22 @@ - func: ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) +- func: linalg_norm(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_norm.ord_str(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + python_module: linalg + variants: function + +- func: linalg_norm.out(Tensor self, Scalar? ord=None, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + +- func: linalg_norm.ord_str_out(Tensor self, str ord, int[]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + python_module: linalg + variants: function + ## Functions that are only for testing # It is undocumented and should not be used outside of tests. - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h new file mode 100644 index 000000000000000..1d6010490c81d79 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h @@ -0,0 +1,301 @@ +#pragma once + +#include +#include + +#include +#include + +#include + +/* Convolution prepacked parameters serialization. + * + * Version 1 + * + * - Fields: + * 1. weight + * 2. bias + * 3. stride x kSpatialDim + * 4. padding x kSpatialDim + * 5. dilation x kSpatialDim + * 6. groups + * + * Version 2 + * + * - Fields: + * 0. version (string) + * 1. list of non-optional tensors + * 0: packed parameters (int16_t) + * - kSpatialDim + * - stride x kSpatialDim + * - padding x kSpatialDim + * - dilation x kSpatialDim + * - output_padding x kSpatialDim (unused) + * - groups + * - transpose (0 or 1, unused) + * 1: weight + * 2. list of optional tensors + * 0: bias + * + * Note: version is a string and conv params are packed into a Tensor + * to make ONNX happy (ints and containers of ints are not supported). + */ + +// version 1 +using ConvParamsSerializationTypeLegacy = std::tuple< + // weight + at::Tensor, + // bias + c10::optional, + // stride x kSpatialDim + torch::List, + // padding x kSpatialDim + torch::List, + // dilation x kSpatialDim + torch::List, + // groups + at::Tensor>; + +// version 2 +using ConvParamsSerializationType = std::tuple< + // version, for versions 2 and up + std::string, + // non-optional tensors + std::vector, + // optional tensors + std::vector>>; + +// Parses any historical conv packed params format into +// the current format. +template +ConvParamsSerializationType parse_conv_serialized_state(c10::IValue v) { + + // determine the version based on IValue contents + int version = -1; + if (v.isTuple()) { + auto elements = v.toTuple()->elements(); + if (elements.size() > 0) { + auto firstElement = elements[0]; + if (firstElement.isTensor()) { + version = 1; + } else if (firstElement.isString()) { + std::string version_str = firstElement.toStringRef(); + // note: not parsing the string to automatically handle bad + // inputs + if (version_str == "2") { + version = 2; + } + } + } + } + TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version"); + + if (version == 1) { + // version 1 - convert to version 2 manually + + auto elements = v.toTuple()->elements(); + + at::Tensor weight = elements[0].toTensor(); + c10::optional bias = elements[1].toOptional(); + torch::List stride_x_kSpatialDim = elements[2].toTensorList(); + torch::List padding_x_kSpatialDim = elements[3].toTensorList(); + torch::List dilation_x_kSpatialDim = elements[4].toTensorList(); + at::Tensor groups = elements[5].toTensor(); + + std::string version = "2"; + std::vector non_optional; + std::vector> optional; + + std::vector params_vec; + params_vec.push_back(kSpatialDim); + for (int i = 0; i < stride_x_kSpatialDim.size(); i++) { + auto stride = stride_x_kSpatialDim.get(i); + params_vec.push_back(stride[0].item()); + } + for (int i = 0; i < padding_x_kSpatialDim.size(); i++) { + auto padding = padding_x_kSpatialDim.get(i); + params_vec.push_back(padding[0].item()); + } + for (int i = 0; i < dilation_x_kSpatialDim.size(); i++) { + auto dilation = dilation_x_kSpatialDim.get(i); + params_vec.push_back(dilation[0].item()); + } + // output_padding does not exist in v1, so we fill in a default value + for (int i = 0; i < kSpatialDim; i++) { + params_vec.push_back(0); + } + params_vec.push_back(groups[0].item()); + // transpose does not exist in v1, so we fill in a default value + params_vec.push_back(0); + int64_t vec_size = params_vec.size(); + at::Tensor params_tensor = at::from_blob(params_vec.data(), + {vec_size}, at::TensorOptions().dtype(at::kShort)) + // clone to retain ownership of the data + .clone(); + + non_optional.emplace_back(std::move(params_tensor)); + non_optional.emplace_back(std::move(weight)); + optional.emplace_back(std::move(bias)); + + return std::tie(version, non_optional, optional); + } else if (version == 2) { + // version 2 + return v.to(); + } else { + TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ", + version); + } +} + +template +ConvParamsSerializationType serialize_conv( + const c10::intrusive_ptr>& params) { + + std::string version = "2"; + std::vector non_optional; + std::vector> optional; + + // create a packed int8_t tensor for conv params + std::vector params_vec; + params_vec.push_back(kSpatialDim); + auto stride = params->stride().vec(); + params_vec.insert(params_vec.end(), stride.begin(), stride.end()); + auto padding = params->padding().vec(); + params_vec.insert(params_vec.end(), padding.begin(), padding.end()); + auto dilation = params->dilation().vec(); + params_vec.insert(params_vec.end(), dilation.begin(), dilation.end()); + // output_padding is not implemented yet, so we fill in a default value + for (int i = 0; i < kSpatialDim; i++) { + params_vec.push_back(0); + } + params_vec.push_back(params->groups()); + // transpose is not implemented yet, so we fill in a default value + params_vec.push_back(0); + int64_t vec_size = params_vec.size(); + at::Tensor params_tensor = at::from_blob( + params_vec.data(), {vec_size}, + at::TensorOptions().dtype(at::kShort)) + // clone to retain ownership of the data + .clone(); + + at::Tensor weight; + c10::optional bias; + std::tie(weight, bias) = params->unpack(); + + non_optional.emplace_back(std::move(params_tensor)); + non_optional.emplace_back(std::move(weight)); + optional.emplace_back(std::move(bias)); + + return std::tie(version, non_optional, optional); +} + +template +ConvParamsSerializationTypeLegacy serialize_conv_legacy( + const c10::intrusive_ptr>& params) { + at::Tensor weight; + c10::optional bias; + std::tie(weight, bias) = params->unpack(); + torch::List stride; + torch::List padding; + torch::List dilation; + at::Tensor groups; + for (int64_t s : params->stride()) { + stride.emplace_back(at::tensor(s)); + } + for (int64_t p : params->padding()) { + padding.emplace_back(at::tensor(p)); + } + for (int64_t d : params->dilation()) { + dilation.emplace_back(at::tensor(d)); + } + groups = at::tensor(params->groups()); + return std::make_tuple( + std::move(weight), + std::move(bias), + stride, + padding, + dilation, + groups); +} + +template +c10::intrusive_ptr> deserialize_conv( + ConvParamsSerializationType state) { + + std::string version; + std::vector non_optional; + std::vector> optional; + + std::tie(version, non_optional, optional) = state; + TORCH_INTERNAL_ASSERT(version == "2", "Unexpected serialized qconv version: ", + version); + + at::Tensor conv_params_packed = non_optional[0]; + at::Tensor weight = non_optional[1]; + c10::optional bias = optional[0]; + + torch::List stride, padding, dilation; + // skip kSpatialDim + int idx = 1; + for (int i = 0; i < kSpatialDim; ++i) { + stride.emplace_back(conv_params_packed[idx].item()); + idx++; + } + for (int i = 0; i < kSpatialDim; ++i) { + padding.emplace_back(conv_params_packed[idx].item()); + idx++; + } + for (int i = 0; i < kSpatialDim; ++i) { + dilation.emplace_back(conv_params_packed[idx].item()); + idx++; + } + // output_padding is not implemented yet, so we skip the entries + for (int i = 0; i < kSpatialDim; ++i) { + // do nothing + idx++; + } + int64_t groups = conv_params_packed[idx].item(); + idx++; + // transpose is not implemented yet, so we skip the entry + idx++; + TORCH_INTERNAL_ASSERT(idx == conv_params_packed.numel(), + "Unexpected length of conv_params_packed, expected ", + idx, + " got ", + conv_params_packed.numel()); + + auto& ctx = at::globalContext(); + +#ifdef USE_FBGEMM + if (ctx.qEngine() == at::QEngine::FBGEMM) { + return PackedConvWeight::prepack( + weight, + bias, + stride, + padding, + dilation, + groups + ); + } +#endif // USE_FBGEMM +#ifdef USE_PYTORCH_QNNPACK + if (ctx.qEngine() == at::QEngine::QNNPACK) { + TORCH_CHECK( + kSpatialDim == 2, + "prepack/__setstate__: QNNPACK only supports Conv2d " + "now."); + return PackedConvWeightsQnnp::prepack( + weight, + bias, + stride, + padding, + dilation, + groups + ); + } +#endif // USE_PYTORCH_QNNPACK +TORCH_CHECK( + false, + "Didn't find engine for when deserializing ConvPackedParams: ", + toString(ctx.qEngine())); +} diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 48fdc04fe159409..6b93d50104f11a5 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -1,10 +1,11 @@ -#include -#include -#include -#include - #include #include + +#include +#include +#include +#include +#include #include #include @@ -213,99 +214,27 @@ Tensor ConvertToChannelsLast3dTensor(const Tensor& src) { template CAFFE2_API torch::class_> register_conv_params() { - using SerializationType = std::tuple< - at::Tensor, - c10::optional, - // these are meant to be torch::List but - // it's not supported by onnx, so we'll use Tensor as - // a workaround - torch::List, - torch::List, - torch::List, - at::Tensor>; static auto register_conv_params = torch::class_>( "quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase") .def_pickle( + /* + [](const c10::intrusive_ptr>& params) + -> ConvParamsSerializationType { // __getstate__ + return serialize_conv(params); + }, + */ + // TODO (#43649): switch this to serialize_conv [](const c10::intrusive_ptr>& params) - -> SerializationType { // __getstate__ - at::Tensor weight; - c10::optional bias; - std::tie(weight, bias) = params->unpack(); - torch::List stride; - torch::List padding; - torch::List dilation; - at::Tensor groups; - for (int64_t s : params->stride()) { - stride.emplace_back(at::tensor(s)); - } - for (int64_t p : params->padding()) { - padding.emplace_back(at::tensor(p)); - } - for (int64_t d : params->dilation()) { - dilation.emplace_back(at::tensor(d)); - } - groups = at::tensor(params->groups()); - return std::make_tuple( - std::move(weight), - std::move(bias), - stride, - padding, - dilation, - groups); + -> ConvParamsSerializationTypeLegacy { // __getstate__ + return serialize_conv_legacy(params); }, - [](SerializationType state) + // __setstate__ takes c10::IValue because we support parsing historical + // serialization versions. + [](c10::IValue v) -> c10::intrusive_ptr> { // __setstate__ - at::Tensor weight; - c10::optional bias; - torch::List stride_tensor, padding_tensor, - dilation_tensor; - at::Tensor groups_tensor; - torch::List stride, padding, dilation; - int64_t groups; - std::tie(weight, bias, stride_tensor, padding_tensor, dilation_tensor, groups_tensor) = state; - for (at::Tensor s : stride_tensor) { - stride.emplace_back(s[0].item()); - } - for (at::Tensor p : padding_tensor) { - padding.emplace_back(p[0].item()); - } - for (at::Tensor d : dilation_tensor) { - dilation.emplace_back(d[0].item()); - } - groups = groups_tensor[0].item(); - auto& ctx = at::globalContext(); - -#ifdef USE_FBGEMM - if (ctx.qEngine() == at::QEngine::FBGEMM) { - return PackedConvWeight::prepack( - weight, - bias, - stride, - padding, - dilation, - groups); - } -#endif // USE_FBGEMM -#ifdef USE_PYTORCH_QNNPACK - if (ctx.qEngine() == at::QEngine::QNNPACK) { - TORCH_CHECK( - kSpatialDim == 2, - "prepack/__setstate__: QNNPACK only supports Conv2d " - "now."); - return PackedConvWeightsQnnp::prepack( - weight, - bias, - stride, - padding, - dilation, - groups); - } -#endif // USE_PYTORCH_QNNPACK - TORCH_CHECK( - false, - "Didn't find engine for when deserializing ConvPackedParams: ", - toString(ctx.qEngine())); + ConvParamsSerializationType state = parse_conv_serialized_state(v); + return deserialize_conv(state); }) .def("weight", [](const c10::intrusive_ptr>& self) { at::Tensor weight; diff --git a/aten/src/ATen/native/quantized/cpu/q_adaavgpool.cpp b/aten/src/ATen/native/quantized/cpu/q_adaavgpool.cpp index 13c631b6bfba4a7..cda1aa79120a757 100644 --- a/aten/src/ATen/native/quantized/cpu/q_adaavgpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/q_adaavgpool.cpp @@ -177,7 +177,8 @@ Tensor _adaptive_avg_pool(const Tensor& input, auto osizeW = output_shape[output_shape.size() - 1]; int64_t sizeB = output_shape.size() ==(kSpatialDim + 1) ? 0 : output_shape[0]; - if (input.is_contiguous(c10::MemoryFormat::ChannelsLast)) { + if (input.is_contiguous(c10::MemoryFormat::ChannelsLast) || + input.is_contiguous(c10::MemoryFormat::ChannelsLast3d)) { // Fast path for NDHWC output = at::_empty_affine_quantized( output_shape, diff --git a/aten/src/ATen/native/vulkan/api/Allocator.cpp b/aten/src/ATen/native/vulkan/api/Allocator.cpp new file mode 100644 index 000000000000000..7b153786dd83c80 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Allocator.cpp @@ -0,0 +1,2 @@ +#define VMA_IMPLEMENTATION +#include diff --git a/aten/src/ATen/native/vulkan/api/Allocator.h b/aten/src/ATen/native/vulkan/api/Allocator.h new file mode 100644 index 000000000000000..afa720a515e6748 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Allocator.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +#ifdef __clang__ + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wnullability-completeness" + #pragma clang diagnostic ignored "-Wunused-variable" +#endif + +// Do NOT include vk_mem_alloc.h directly. +// Always include this file (Allocator.h) instead. + +#include + +#ifdef __clang__ + #pragma clang diagnostic pop +#endif diff --git a/aten/src/ATen/native/vulkan/api/Cache.h b/aten/src/ATen/native/vulkan/api/Cache.h new file mode 100644 index 000000000000000..36291a2227d4fb3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Cache.h @@ -0,0 +1,88 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +// +// A generic cache for immutable Vulkan objects, when there will not be many +// instances of those objects required at runtime. The previous sentence puts +// two constraints on proper use of this cache: 1) First, the objects should +// preferably be immutable otherwise much care is required to synchronize +// their usage. 2) Second, this cache is only intended for objects that +// we will not have many instances of during the entire execution of the +// program, otherwise the cache must be _infrequently_ purged. Proper usage +// model for this cache is in direct contrast with Vulkan object pools, which +// indeed are required to be _frequently_ purged. That is an important +// distinction. +// + +template +class Cache final { + public: + explicit Cache(Factory factory); + Cache(const Cache&) = delete; + Cache& operator=(const Cache&) = delete; + Cache(Cache&&) = default; + Cache& operator=(Cache&&) = default; + ~Cache() = default; + + // Factory must have the following symbols defined. + + typedef typename Factory::Descriptor Descriptor; + typedef typename Factory::Handle Handle; + typedef typename Factory::Hasher Hasher; + + // Create or retrieve a resource. + // + // This operation is a simple cache lookup and returns the Handle corresponding + // to the descriptor if the object is already present in the cache. Otherwise, + // Factory is used to create the object, after which point the object is added + // to the cache. Regardless, this function returns with the object in the cache. + + auto retrieve(const Descriptor& descriptor); + + // Only call this function infrequently, if ever. This cache is only intended + // for immutable Vulkan objects of which a small finite instances are required + // at runtime. A good place to call this function is between model loads. + + void purge(); + + private: + struct Configuration final { + static constexpr uint32_t kReserve = 64u; + }; + + ska::flat_hash_map cache_; + Factory factory_; +}; + +template +inline Cache::Cache(Factory factory) + : factory_(std::move(factory)) { + cache_.reserve(Configuration::kReserve); +} + +template +inline auto Cache::retrieve( + const Descriptor& descriptor) { + auto iterator = cache_.find(descriptor); + if (cache_.cend() == iterator) { + iterator = cache_.insert({descriptor, factory_(descriptor)}).first; + } + + return iterator->second.get(); +} + +template +inline void Cache::purge() { + cache_.clear(); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Command.cpp b/aten/src/ATen/native/vulkan/api/Command.cpp new file mode 100644 index 000000000000000..21279b40823311a --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Command.cpp @@ -0,0 +1,107 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +Command::Pool::Factory::Factory(const VkDevice device) + : device_(device) { +} + +typename Command::Pool::Factory::Handle Command::Pool::Factory::operator()( + const Descriptor& descriptor) const { + const VkCommandPoolCreateInfo command_pool_create_info{ + VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO, + nullptr, + VK_COMMAND_POOL_CREATE_TRANSIENT_BIT, + descriptor.queue_family_index, + }; + + VkCommandPool command_pool{}; + VK_CHECK(vkCreateCommandPool( + device_, &command_pool_create_info, nullptr, &command_pool)); + + return Handle{ + command_pool, + Deleter(device_), + }; +} + +void Command::Pool::purge( + const VkDevice device, + const VkCommandPool command_pool) { + TORCH_INTERNAL_ASSERT(device, "Invalid Vulkan device!"); + TORCH_INTERNAL_ASSERT(command_pool, "Invalid Vulkan command pool!"); + + VK_CHECK(vkResetCommandPool(device, command_pool, 0u)); +} + +namespace { + +VkCommandBuffer allocate_command_buffer( + const VkDevice device, + const VkCommandPool command_pool) { + const VkCommandBufferAllocateInfo command_buffer_allocate_info{ + VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO, + nullptr, + command_pool, + VK_COMMAND_BUFFER_LEVEL_PRIMARY, + 1u, + }; + + VkCommandBuffer command_buffer{}; + VK_CHECK(vkAllocateCommandBuffers( + device, &command_buffer_allocate_info, &command_buffer)); + + return command_buffer; +} + +} // namespace + +Command::Buffer::Buffer(const VkDevice device, const VkCommandPool command_pool) + : command_buffer_(allocate_command_buffer(device, command_pool)) { +} + +void Command::Buffer::Buffer::begin() { + const VkCommandBufferBeginInfo command_buffer_begin_info{ + VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO, + nullptr, + VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT, + nullptr, + }; + + VK_CHECK(vkBeginCommandBuffer(command_buffer_, &command_buffer_begin_info)); +} + +void Command::Buffer::Buffer::end() { + VK_CHECK(vkEndCommandBuffer(command_buffer_)); +} + +void Command::Buffer::bind(const VkPipeline pipeline) { + TORCH_INTERNAL_ASSERT(pipeline, "Invalid Vulkan pipeline!"); + + vkCmdBindPipeline(command_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); +} + +void Command::Buffer::bind( + const VkPipelineLayout pipeline_layout, + const VkDescriptorSet descriptor_set) { + TORCH_INTERNAL_ASSERT(pipeline_layout, "Invalid Vulkan pipeline layout!"); + TORCH_INTERNAL_ASSERT(descriptor_set, "Invalid Vulkan descriptor set!"); + + vkCmdBindDescriptorSets( + command_buffer_, + VK_PIPELINE_BIND_POINT_COMPUTE, + pipeline_layout, + 0u, + 1u, + &descriptor_set, + 0u, + nullptr); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Command.h b/aten/src/ATen/native/vulkan/api/Command.h new file mode 100644 index 000000000000000..462a50fef7fd4a5 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Command.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +struct C10_EXPORT Command final { + // + // Pool + // + + struct Pool final { + /* + Descriptor + */ + + struct Descriptor final { + uint32_t queue_family_index; + }; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(VkDevice device); + + typedef Pool::Descriptor Descriptor; + typedef VK_DELETER(CommandPool) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Pool(const VkDevice device) + : cache(Factory(device)) { + } + + static void purge(VkDevice device, VkCommandPool command_pool); + } pool; + + // + // Buffer + // + + class Buffer final { + public: + Buffer(VkDevice device, VkCommandPool command_pool); + + void begin(); + void end(); + + void bind(VkPipeline pipeline); + void bind(VkPipelineLayout pipeline_layout, VkDescriptorSet descriptor_set); + void dispatch(); + + private: + VkCommandBuffer command_buffer_; + }; + + explicit Command(const VkDevice device) + : pool(device) { + } +}; + +// +// Impl +// + +inline bool operator==( + const Command::Pool::Descriptor& _1, + const Command::Pool::Descriptor& _2) { + return _1.queue_family_index == _2.queue_family_index; +} + +inline size_t Command::Pool::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + return c10::get_hash(descriptor.queue_family_index); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Common.cpp b/aten/src/ATen/native/vulkan/api/Common.cpp new file mode 100644 index 000000000000000..8749d4b420e014e --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Common.cpp @@ -0,0 +1,47 @@ +#include + +#define VK_DELETER_DISPATCHABLE_DEFINE(Handle) \ + VK_DELETER_DISPATCHABLE_DECLARE(Handle) { \ + if (C10_LIKELY(VK_NULL_HANDLE != handle)) { \ + vkDestroy##Handle(handle, nullptr); \ + } \ + } + +#define VK_DELETER_NON_DISPATCHABLE_DEFINE(Handle) \ + destroy_##Handle::destroy_##Handle(const VkDevice device) \ + : device_(device) { \ + } \ + \ + void destroy_##Handle::operator()(const Vk##Handle handle) const { \ + if (C10_LIKELY(VK_NULL_HANDLE != handle)) { \ + vkDestroy##Handle(device_, handle, nullptr); \ + } \ + } + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +VK_DELETER_DISPATCHABLE_DEFINE(Instance); +VK_DELETER_DISPATCHABLE_DEFINE(Device); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Semaphore); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Fence); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Buffer); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Image); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Event); +VK_DELETER_NON_DISPATCHABLE_DEFINE(BufferView); +VK_DELETER_NON_DISPATCHABLE_DEFINE(ImageView); +VK_DELETER_NON_DISPATCHABLE_DEFINE(ShaderModule); +VK_DELETER_NON_DISPATCHABLE_DEFINE(PipelineCache); +VK_DELETER_NON_DISPATCHABLE_DEFINE(PipelineLayout); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Pipeline); +VK_DELETER_NON_DISPATCHABLE_DEFINE(DescriptorSetLayout); +VK_DELETER_NON_DISPATCHABLE_DEFINE(Sampler); +VK_DELETER_NON_DISPATCHABLE_DEFINE(DescriptorPool); +VK_DELETER_NON_DISPATCHABLE_DEFINE(CommandPool); + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Common.h b/aten/src/ATen/native/vulkan/api/Common.h new file mode 100644 index 000000000000000..0c1e7cc4720bd7f --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Common.h @@ -0,0 +1,158 @@ +#pragma once + +#include + +#ifdef USE_VULKAN_WRAPPER +#include +#else +#include +#endif + +#define VK_CHECK(function) \ + { \ + const VkResult result = (function); \ + TORCH_CHECK(VK_SUCCESS == result, "VkResult:", result); \ + } + +#define VK_CHECK_RELAXED(function) \ + { \ + const VkResult result = (function); \ + TORCH_CHECK(VK_SUCCESS <= result, "VkResult:", result); \ + } + +#define VK_DELETER(Handle) \ + at::native::vulkan::api::destroy_##Handle + +#define VK_DELETER_DISPATCHABLE_DECLARE(Handle) \ + C10_EXPORT void destroy_##Handle(const Vk##Handle handle) + +#define VK_DELETER_NON_DISPATCHABLE_DECLARE(Handle) \ + class C10_EXPORT destroy_##Handle final { \ + public: \ + explicit destroy_##Handle(const VkDevice device); \ + void operator()(const Vk##Handle handle) const; \ + private: \ + VkDevice device_; \ + }; + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +VK_DELETER_DISPATCHABLE_DECLARE(Instance); +VK_DELETER_DISPATCHABLE_DECLARE(Device); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Semaphore); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Fence); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Buffer); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Image); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Event); +VK_DELETER_NON_DISPATCHABLE_DECLARE(BufferView); +VK_DELETER_NON_DISPATCHABLE_DECLARE(ImageView); +VK_DELETER_NON_DISPATCHABLE_DECLARE(ShaderModule); +VK_DELETER_NON_DISPATCHABLE_DECLARE(PipelineCache); +VK_DELETER_NON_DISPATCHABLE_DECLARE(PipelineLayout); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Pipeline); +VK_DELETER_NON_DISPATCHABLE_DECLARE(DescriptorSetLayout); +VK_DELETER_NON_DISPATCHABLE_DECLARE(Sampler); +VK_DELETER_NON_DISPATCHABLE_DECLARE(DescriptorPool); +VK_DELETER_NON_DISPATCHABLE_DECLARE(CommandPool); + +// Vulkan objects are referenced via handles. The spec defines Vulkan handles +// under two categories: dispatchable and non-dispatchable. Dispatchable handles +// are required to be strongly typed as a result of being pointers to unique +// opaque types. Since dispatchable handles are pointers at the heart, +// std::unique_ptr can be used to manage their lifetime with a custom deleter. +// Non-dispatchable handles on the other hand, are not required to have strong +// types, and even though they default to the same implementation as dispatchable +// handles on some platforms - making the use of std::unique_ptr possible - they +// are only required by the spec to weakly aliases 64-bit integers which is the +// implementation some platforms default to. This makes the use of std::unique_ptr +// difficult since semantically unique_ptrs store pointers to their payload +// which is also what is passed onto the custom deleters. + +template +class Handle final { + public: + Handle(Type payload, Deleter deleter); + Handle(const Handle&) = delete; + Handle& operator=(const Handle&) = delete; + Handle(Handle&&); + Handle& operator=(Handle&&); + ~Handle(); + + operator bool() const; + Type get() const; + Type release(); + void reset(Type payload = kNull); + + private: + static constexpr Type kNull{}; + + private: + Type payload_; + Deleter deleter_; +}; + +// +// Impl +// + +template +inline Handle::Handle(const Type payload, Deleter deleter) + : payload_(payload), + deleter_(std::move(deleter)) { +} + +template +inline Handle::Handle(Handle&& handle) + : payload_(handle.release()), + deleter_(std::move(handle.deleter_)) { +} + +template +inline Handle& +Handle::operator=(Handle&& handle) +{ + reset(handle.release()); + deleter_ = std::move(handle.deleter_); + return *this; +} + +template +inline Handle::~Handle() { + reset(); +} + +template +inline Handle::operator bool() const { + return get(); +} + +template +inline Type Handle::get() const { + return payload_; +} + +template +inline Type Handle::release() { + const Type payload = payload_; + payload_ = kNull; + + return payload; +} + +template +inline void Handle::reset(Type payload) { + using std::swap; + swap(payload_, payload); + + if (kNull != payload) { + deleter_(payload); + } +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Context.cpp b/aten/src/ATen/native/vulkan/api/Context.cpp new file mode 100644 index 000000000000000..76a245e16d38993 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Context.cpp @@ -0,0 +1,322 @@ +#include + +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { +namespace { + +struct Configuration final { +#ifndef DEBUG + static constexpr bool kEnableValidationLayers = false; +#else + static constexpr bool kEnableValidationLayers = true; +#endif +}; + +VKAPI_ATTR VkBool32 VKAPI_CALL debug_report_callback_fn( + const VkDebugReportFlagsEXT flags, + const VkDebugReportObjectTypeEXT /* object_type */, + const uint64_t /* object */, + const size_t /* location */, + const int32_t message_code, + const char* const layer_prefix, + const char* const message, + void* const /* user_data */) { + std::stringstream stream; + stream << layer_prefix << " " << message_code << " " << message << std::endl; + const std::string log = stream.str(); + + if (flags & VK_DEBUG_REPORT_ERROR_BIT_EXT) { + LOG(ERROR) << log; + } else if (flags & VK_DEBUG_REPORT_WARNING_BIT_EXT) { + LOG(WARNING) << log; + } else if (flags & VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT) { + LOG(WARNING) << "Performance:" << log; + } else if (flags & VK_DEBUG_REPORT_INFORMATION_BIT_EXT) { + LOG(INFO) << log; + } else if (flags & VK_DEBUG_REPORT_DEBUG_BIT_EXT) { + LOG(INFO) << "Debug: " << log; + } + + return VK_FALSE; +} + +VkInstance create_instance(const bool enable_validation_layers) { + std::vector enabled_instance_layers; + std::vector enabled_instance_extensions; + + if (enable_validation_layers) { + uint32_t instance_layers_count = 0; + VK_CHECK(vkEnumerateInstanceLayerProperties( + &instance_layers_count, nullptr)); + + std::vector instance_layer_properties( + instance_layers_count); + + VK_CHECK(vkEnumerateInstanceLayerProperties( + &instance_layers_count, + instance_layer_properties.data())); + + constexpr const char* const requested_instance_layers[]{ + // "VK_LAYER_LUNARG_api_dump", + "VK_LAYER_KHRONOS_validation", + }; + + for (const auto& requested_instance_layer : requested_instance_layers) { + for (const auto& layer : instance_layer_properties) { + if (strcmp(requested_instance_layer, layer.layerName) == 0) { + enabled_instance_layers.push_back(requested_instance_layer); + break; + } + } + } + + uint32_t instance_extension_count = 0; + VK_CHECK(vkEnumerateInstanceExtensionProperties( + nullptr, &instance_extension_count, nullptr)); + + std::vector instance_extension_properties( + instance_extension_count); + + VK_CHECK(vkEnumerateInstanceExtensionProperties( + nullptr, &instance_extension_count, instance_extension_properties.data())); + + constexpr const char* const requested_instance_extensions[]{ + VK_EXT_DEBUG_REPORT_EXTENSION_NAME, + }; + + for (const auto& requested_instance_extension : requested_instance_extensions) { + for (const auto& extension : instance_extension_properties) { + if (strcmp(requested_instance_extension, extension.extensionName) == 0) { + enabled_instance_extensions.push_back(requested_instance_extension); + break; + } + } + } + } + + constexpr VkApplicationInfo application_info{ + VK_STRUCTURE_TYPE_APPLICATION_INFO, + nullptr, + "PyTorch", + 0, + "PyTorch", + 0, + VK_API_VERSION_1_0, + }; + + const VkInstanceCreateInfo instance_create_info{ + VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO, + nullptr, + 0u, + &application_info, + static_cast(enabled_instance_layers.size()), + enabled_instance_layers.data(), + static_cast(enabled_instance_extensions.size()), + enabled_instance_extensions.data(), + }; + + VkInstance instance{}; + VK_CHECK(vkCreateInstance(&instance_create_info, nullptr, &instance)); + + return instance; +} + +VkDebugReportCallbackEXT create_debug_report_callback( + const VkInstance instance, + const bool enable_validation_layers) { + if (!enable_validation_layers) { + return VkDebugReportCallbackEXT{}; + } + + const VkDebugReportCallbackCreateInfoEXT debugReportCallbackCreateInfo{ + VK_STRUCTURE_TYPE_DEBUG_REPORT_CALLBACK_CREATE_INFO_EXT, + nullptr, + VK_DEBUG_REPORT_INFORMATION_BIT_EXT | + VK_DEBUG_REPORT_WARNING_BIT_EXT | + VK_DEBUG_REPORT_PERFORMANCE_WARNING_BIT_EXT | + VK_DEBUG_REPORT_ERROR_BIT_EXT | + VK_DEBUG_REPORT_DEBUG_BIT_EXT, + debug_report_callback_fn, + nullptr, + }; + + const auto vkCreateDebugReportCallbackEXT = + (PFN_vkCreateDebugReportCallbackEXT)vkGetInstanceProcAddr( + instance, "vkCreateDebugReportCallbackEXT"); + + TORCH_CHECK( + vkCreateDebugReportCallbackEXT, + "Could not load vkCreateDebugReportCallbackEXT"); + + VkDebugReportCallbackEXT debug_report_callback{}; + VK_CHECK(vkCreateDebugReportCallbackEXT( + instance, + &debugReportCallbackCreateInfo, + nullptr, + &debug_report_callback)); + + return debug_report_callback; +} + +VkPhysicalDevice acquire_physical_device(const VkInstance instance) { + uint32_t device_count = 0; + VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, nullptr)); + TORCH_CHECK(device_count > 0, "Vulkan: Could not find a device with Vulkan support!"); + + std::vector devices(device_count); + VK_CHECK(vkEnumeratePhysicalDevices(instance, &device_count, devices.data())); + + return devices[0]; +} + +VkPhysicalDeviceLimits query_physical_device_physical_device_limits( + const VkPhysicalDevice physical_device) { + VkPhysicalDeviceProperties physical_device_properties{}; + vkGetPhysicalDeviceProperties(physical_device, &physical_device_properties); + return physical_device_properties.limits; +} + +uint32_t query_compute_queue_family_index(const VkPhysicalDevice physical_device) { + uint32_t queue_family_count = 0; + + vkGetPhysicalDeviceQueueFamilyProperties( + physical_device, &queue_family_count, nullptr); + + TORCH_CHECK( + queue_family_count > 0, "Vulkan: Invalid number of queue families!"); + + std::vector queue_families_properties( + queue_family_count); + + vkGetPhysicalDeviceQueueFamilyProperties( + physical_device, &queue_family_count, queue_families_properties.data()); + + for (uint32_t i = 0; i < queue_families_properties.size(); ++i) { + const VkQueueFamilyProperties& properties = queue_families_properties[i]; + if (properties.queueCount > 0 && (properties.queueFlags & VK_QUEUE_COMPUTE_BIT)) { + return i; + } + } + + TORCH_CHECK( + false, + "Vulkan: Could not find a queue family that supports compute operations!"); +} + +VkDevice create_device( + const VkPhysicalDevice physical_device, + const uint32_t compute_queue_family_index) { + const float queue_priorities = 1.0f; + const VkDeviceQueueCreateInfo device_queue_create_info{ + VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO, + nullptr, + 0u, + compute_queue_family_index, + 1u, + &queue_priorities, + }; + + const VkDeviceCreateInfo device_create_info{ + VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO, + nullptr, + 0u, + 1u, + &device_queue_create_info, + 0u, + nullptr, + 0u, + nullptr, + }; + + VkDevice device{}; + VK_CHECK(vkCreateDevice(physical_device, &device_create_info, nullptr, &device)); + + return device; +} + +VkQueue acquire_queue( + const VkDevice device, + const uint32_t compute_queue_family_index) { + VkQueue queue{}; + vkGetDeviceQueue(device, compute_queue_family_index, 0, &queue); + return queue; +} + +} // namespace + +Context::Context(const bool enable_validation_layers) + : instance_(create_instance(enable_validation_layers), &VK_DELETER(Instance)), + debug_report_callback_( + create_debug_report_callback(instance(), enable_validation_layers), + Debug(instance())), + physical_device_(acquire_physical_device(instance())), + physical_device_limits_(query_physical_device_physical_device_limits(physical_device())), + compute_queue_family_index_(query_compute_queue_family_index(physical_device())), + device_(create_device(physical_device(), compute_queue_family_index_), &VK_DELETER(Device)), + queue_(acquire_queue(device(), compute_queue_family_index_)), + command_(device()), + shader_(device()), + pipeline_(device()), + descriptor_(device()), + resource_(instance(), physical_device(), device()) { +} + +Context::Debug::Debug(const VkInstance instance) + : instance_(instance) { +} + +void Context::Debug::operator()( + const VkDebugReportCallbackEXT debug_report_callback) const { + if (debug_report_callback) { + const auto vkDestroyDebugReportCallbackEXT = + (PFN_vkDestroyDebugReportCallbackEXT)vkGetInstanceProcAddr( + instance_, "vkDestroyDebugReportCallbackEXT"); + + TORCH_CHECK( + vkDestroyDebugReportCallbackEXT, + "Could not load vkDestroyDebugReportCallbackEXT"); + + vkDestroyDebugReportCallbackEXT( + instance_, debug_report_callback, nullptr); + } +} + +Context* initialize() { + static const std::unique_ptr context([]() -> Context* { +#ifdef USE_VULKAN_WRAPPER + if (!InitVulkan()) { + TORCH_WARN("Vulkan: Wrapper Failed to InitVulkan"); + return nullptr; + } +#endif + + try { + return new Context(Configuration::kEnableValidationLayers); + } + catch (...) { + return nullptr; + } + }()); + + return context.get(); +} + +bool available() { + return initialize(); +} + +Context& context() { + Context* const context = initialize(); + TORCH_CHECK(context, "Vulkan: Backend not available on this platform!"); + + return *context; +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Context.h b/aten/src/ATen/native/vulkan/api/Context.h new file mode 100644 index 000000000000000..d57eab66108e14b --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Context.h @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +// +// Vulkan Context holds onto all relevant Vulkan state as it pertains to our +// use of Vulkan in PyTorch. The context is currently a global object, but +// technically it does not need to be if we were to make it explicit to the +// user. +// + +class C10_EXPORT Context final { + public: + explicit Context(bool enable_validation_layers); + ~Context() = default; + + inline VkInstance instance() const { + return instance_.get(); + } + + inline VkPhysicalDevice physical_device() const { + return physical_device_; + } + + inline const VkPhysicalDeviceLimits& physical_device_limits() const { + return physical_device_limits_; + } + + inline VkDevice device() const { + return device_.get(); + } + + inline VkQueue queue() const { + return queue_; + } + + inline Command& command() { + return command_; + } + + inline Shader& shader() { + return shader_; + } + + inline Pipeline& pipeline() { + return pipeline_; + } + + inline Descriptor& descriptor() { + return descriptor_; + } + + inline Resource& resource() { + return resource_; + } + + private: + class Debug final { + public: + explicit Debug(VkInstance instance); + void operator()(VkDebugReportCallbackEXT debug_report_callback) const; + + private: + VkInstance instance_; + }; + + private: + // Construction and destruction order matters. Do not move members around. + Handle instance_; + Handle debug_report_callback_; + VkPhysicalDevice physical_device_; + VkPhysicalDeviceLimits physical_device_limits_; + uint32_t compute_queue_family_index_; + Handle device_; + VkQueue queue_; + Command command_; + Shader shader_; + Pipeline pipeline_; + Descriptor descriptor_; + Resource resource_; +}; + +C10_EXPORT bool available(); +C10_EXPORT Context& context(); + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.cpp b/aten/src/ATen/native/vulkan/api/Descriptor.cpp new file mode 100644 index 000000000000000..1b5ea94341a36f9 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Descriptor.cpp @@ -0,0 +1,107 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +const Descriptor::Pool::Descriptor Descriptor::Pool::kDefault{ + 1024u, + { + // Note: It is OK for the sum of descriptors per type, below, to exceed + // the max total figure above, but be concenious of memory consumption. + // Considering how the descriptor pool must be frequently purged anyway + // as a result of the impracticality of having enormous pools that + // persist through the execution of the program, there is diminishing + // return in increasing max counts. + { + /* + Buffers + */ + + { + VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + 256u, + }, + { + VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + 256u, + }, + + /* + Images + */ + + { + VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, + 256u, + }, + { + VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + 256u, + }, + }, + }, +}; + +Descriptor::Pool::Factory::Factory(const VkDevice device) + : device_(device) { +} + +typename Descriptor::Pool::Factory::Handle Descriptor::Pool::Factory::operator()( + const Descriptor& descriptor) const { + const VkDescriptorPoolCreateInfo descriptor_pool_create_info{ + VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO, + nullptr, + 0u, /* Do not use VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT */ + descriptor.capacity, + static_cast(descriptor.sizes.size()), + descriptor.sizes.data(), + }; + + VkDescriptorPool descriptor_pool{}; + VK_CHECK(vkCreateDescriptorPool( + device_, &descriptor_pool_create_info, nullptr, &descriptor_pool)); + + return Handle{ + descriptor_pool, + Deleter(device_), + }; +} + +void Descriptor::Pool::purge( + const VkDevice device, + const VkDescriptorPool descriptor_pool) { + VK_CHECK(vkResetDescriptorPool(device, descriptor_pool, 0u)); +} + +Descriptor::Factory::Factory(const VkDevice device, const VkDescriptorPool descriptor_pool) + : device_(device), + descriptor_pool_(descriptor_pool) { +} + +VkDescriptorSet Descriptor::Factory::allocate( + const VkDescriptorSetLayout descriptor_set_layout) { + const VkDescriptorSetAllocateInfo descriptor_set_allocate_info{ + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO, + nullptr, + descriptor_pool_, + 1u, + &descriptor_set_layout, + }; + + VkDescriptorSet descriptor_set{}; + VK_CHECK(vkAllocateDescriptorSets( + device_, &descriptor_set_allocate_info, &descriptor_set)); + + return descriptor_set; +} + +void Descriptor::Factory::purge() { + Pool::purge(device_, descriptor_pool_); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Descriptor.h b/aten/src/ATen/native/vulkan/api/Descriptor.h new file mode 100644 index 000000000000000..3e339ae4641fe8c --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Descriptor.h @@ -0,0 +1,163 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +// +// This struct defines caches of descriptor pools, and descriptor sets allocated +// from those pools, intended to minimize redundant object reconstructions or +// accelerate unavoidable memory allocations, both at the cost of extra memory +// consumption. +// +// A descriptor set is logically an array of descriptors, each of which +// references a resource (i.e. buffers and images), in turn telling the core +// executing the shader, where in GPU, or GPU-accessible system, memory the said +// resource resides. +// +// To accelerate creation of the descriptor sets, modern graphics APIs allocate +// them from a pool, more elaborately referred to as descriptor pools, which do +// need to be purged frequently _after_ none of the descriptors the pools contain +// is in use by the GPU. Care must be taken that descriptors are not freed while +// they are in use by the pipeline, which considering the asynchronous nature of +// CPU-GPU interactions, can be anytime after the command is issued until it is +// fully executed by the GPU. +// +// As you can imagine, it is possible to have multiple descriptor pools, each of +// which is configured to house different types of descriptor sets with different +// allocation strategies. These descriptor pools themselves are fairly stable +// objects in that they theymself should not be created and destroyed frequently. +// That is the reason why we store them in a cache, which according to our usage +// of the term 'cache' in this implementatoin, is reserved for objects that are +// created infrequently and stabilize to a manageable number quickly over the +// lifetime of the program. +// +// Descriptor sets though, on the other hand, are allocated from pools which +// indeed does mean that the pools must be purged on a regular basis or else +// they will run out of free items. Again, this is in line with our usage of +// the term 'pool' in this implementation which we use to refer to a container +// of objects that is allocated out of and is required to be frequently purged. +// +// It is important to point out that for performance reasons, we intentionally +// do not free the descriptor sets individually, and instead opt to purge the +// pool in its totality, even though Vulkan supports the former usage pattern +// as well. This behavior is by design. +// + +struct C10_EXPORT Descriptor final { + // + // Pool + // + + struct Pool final { + /* + Descriptor + */ + + struct Descriptor final { + uint32_t capacity; + c10::SmallVector sizes; + }; + + static const Descriptor kDefault; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(VkDevice device); + + typedef Pool::Descriptor Descriptor; + typedef VK_DELETER(DescriptorPool) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Pool(const VkDevice device) + : cache(Factory(device)) { + } + + static void purge(VkDevice device, VkDescriptorPool descriptor_pool); + } pool; + + /* + Factory + */ + + class Factory final { + public: + Factory(VkDevice device, VkDescriptorPool descriptor_pool); + + VkDescriptorSet allocate(VkDescriptorSetLayout descriptor_set_layout); + void purge(); + + private: + VkDevice device_; + VkDescriptorPool descriptor_pool_; + } factory; + + explicit Descriptor(const VkDevice device) + : pool(device), + factory(device, pool.cache.retrieve(Pool::kDefault)) { + } +}; + +// +// Impl +// + +inline bool operator==( + const Descriptor::Pool::Descriptor& _1, + const Descriptor::Pool::Descriptor& _2) { + return (_1.capacity == _2.capacity) && + (_1.sizes == _2.sizes); +} + +inline size_t Descriptor::Pool::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + size_t hash = c10::get_hash(descriptor.capacity); + + for (const VkDescriptorPoolSize& descriptor_pool_size : descriptor.sizes) { + hash = c10::hash_combine( + hash, + c10::get_hash( + descriptor_pool_size.type, + descriptor_pool_size.descriptorCount)); + } + + return hash; +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at + +inline bool operator==( + const VkDescriptorPoolSize& descriptor_pool_size_1, + const VkDescriptorPoolSize& descriptor_pool_size_2) { + return (descriptor_pool_size_1.type == descriptor_pool_size_2.type) && + (descriptor_pool_size_1.descriptorCount == descriptor_pool_size_2.descriptorCount); +} diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.cpp b/aten/src/ATen/native/vulkan/api/Pipeline.cpp new file mode 100644 index 000000000000000..303eea7cb4012c5 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Pipeline.cpp @@ -0,0 +1,127 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +Pipeline::Layout::Factory::Factory(const VkDevice device) + : device_(device) { +} + +typename Pipeline::Layout::Factory::Handle Pipeline::Layout::Factory::operator()( + const Descriptor& descriptor) const { + const VkPipelineLayoutCreateInfo pipeline_layout_create_info{ + VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO, + nullptr, + 0u, + 1u, + &descriptor.descriptor_set_layout, + 0u, + nullptr, + }; + + VkPipelineLayout pipeline_layout{}; + VK_CHECK(vkCreatePipelineLayout( + device_, &pipeline_layout_create_info, nullptr, &pipeline_layout)); + + return Handle{ + pipeline_layout, + Deleter(device_), + }; +} + +namespace { + +VkPipelineCache create_pipeline_cache(const VkDevice device) { + const VkPipelineCacheCreateInfo pipeline_cache_create_info{ + VK_STRUCTURE_TYPE_PIPELINE_CACHE_CREATE_INFO, + nullptr, + 0u, + 0u, + nullptr, + }; + + VkPipelineCache pipeline_cache{}; + VK_CHECK(vkCreatePipelineCache( + device, &pipeline_cache_create_info, nullptr, &pipeline_cache)); + + return pipeline_cache; +} + +} // namespace + +Pipeline::Factory::Factory(const VkDevice device) + : device_(device), + pipeline_cache_(create_pipeline_cache(device), VK_DELETER(PipelineCache)(device)) { +} + +typename Pipeline::Factory::Handle Pipeline::Factory::operator()( + const Descriptor& descriptor) const { + constexpr uint32_t x_offset = 0u; + constexpr uint32_t x_size = sizeof(Shader::WorkGroup::x); + constexpr uint32_t y_offset = x_offset + x_size; + constexpr uint32_t y_size = sizeof(Shader::WorkGroup::y); + constexpr uint32_t z_offset = y_offset + y_size; + constexpr uint32_t z_size = sizeof(Shader::WorkGroup::z); + + constexpr VkSpecializationMapEntry specialization_map_entires[3]{ + // X + { + 1u, + x_offset, + x_size, + }, + // Y + { + 2u, + y_offset, + y_size, + }, + // Z + { + 3u, + z_offset, + z_size, + }, + }; + + const VkSpecializationInfo specialization_info{ + 3u, + specialization_map_entires, + sizeof(Shader::WorkGroup), + &descriptor.work_group, + }; + + const VkComputePipelineCreateInfo compute_pipeline_create_info{ + VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, + nullptr, + 0u, + VkPipelineShaderStageCreateInfo{ + VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, + nullptr, + 0u, + VK_SHADER_STAGE_COMPUTE_BIT, + descriptor.shader_module, + "main", + &specialization_info, + }, + descriptor.pipeline_layout, + VK_NULL_HANDLE, + 0u, + }; + + VkPipeline pipeline{}; + VK_CHECK(vkCreateComputePipelines( + device_, pipeline_cache_.get(), 1u, &compute_pipeline_create_info, nullptr, &pipeline)); + + return Handle{ + pipeline, + Deleter(device_), + }; +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Pipeline.h b/aten/src/ATen/native/vulkan/api/Pipeline.h new file mode 100644 index 000000000000000..a5d72324c36e574 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Pipeline.h @@ -0,0 +1,162 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +// +// This struct defines pipeline, and pipeline layout, caches intended to minimize +// redundant object reconstructions at the cost of extra memory consumption. +// +// A Vulkan pipeline contains the entirety of states, as one coherent monolithic +// bundle, required to configure the GPU's execution pipeline. This usage +// pattern minimizes driver overhead, promotes pipeline state reuse, and is a +// departure from, and in direct contrast with, OpenGL's individually confiurable +// state machine. +// +// A Vulkan pipeline layout represents a sequence of Vulkan descriptor sets each +// having a specific layout, and deterimines the interface between all shader +// stages and shader resources. For more information on shaders and shader +// layouts check the description of at::navie::vulkan::api::Shader. +// +// This struct defines the facilities required to create, reuse, and destruct +// these Vulkan objects. +// + +struct C10_EXPORT Pipeline final { + // + // Layout + // + + struct Layout final { + /* + Descriptor + */ + + struct Descriptor final { + VkDescriptorSetLayout descriptor_set_layout; + }; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(VkDevice device); + + typedef Layout::Descriptor Descriptor; + typedef VK_DELETER(PipelineLayout) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Layout(const VkDevice device) + : cache(Factory(device)) { + } + } layout; + + /* + Descriptor + */ + + struct Descriptor final { + VkPipelineLayout pipeline_layout; + VkShaderModule shader_module; + Shader::WorkGroup work_group; + }; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(VkDevice device); + + typedef Pipeline::Descriptor Descriptor; + typedef VK_DELETER(Pipeline) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + api::Handle pipeline_cache_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Pipeline(const VkDevice device) + : layout(device), + cache(Factory(device)) { + } +}; + +// +// Impl +// + +inline bool operator==( + const Pipeline::Layout::Descriptor& _1, + const Pipeline::Layout::Descriptor& _2) { + return (_1.descriptor_set_layout == _2.descriptor_set_layout); +} + +inline size_t Pipeline::Layout::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + return c10::get_hash(descriptor.descriptor_set_layout); +} + +inline bool operator==( + const Pipeline::Descriptor& _1, + const Pipeline::Descriptor& _2) { + return (_1.pipeline_layout == _2.pipeline_layout) && + (_1.shader_module == _2.shader_module) && + (_1.work_group == _2.work_group); +} + +inline size_t Pipeline::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + return c10::get_hash( + descriptor.pipeline_layout, + descriptor.shader_module, + descriptor.work_group.x, + descriptor.work_group.y, + descriptor.work_group.z); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Resource.cpp b/aten/src/ATen/native/vulkan/api/Resource.cpp new file mode 100644 index 000000000000000..c538a1b6e2d08da --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Resource.cpp @@ -0,0 +1,241 @@ +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { +namespace { + +VmaAllocator create_allocator( + const VkInstance instance, + const VkPhysicalDevice physical_device, + const VkDevice device) { + const VmaAllocatorCreateInfo allocator_create_info{ + 0u, + physical_device, + device, + 0u, + nullptr, + nullptr, + 1u, + nullptr, + nullptr, // TODO (Ashkan): VULKAN_WRAPPER + nullptr, + instance, + VK_API_VERSION_1_0, + }; + + VmaAllocator allocator{}; + VK_CHECK(vmaCreateAllocator(&allocator_create_info, &allocator)); + + return allocator; +} + +VmaAllocationCreateInfo create_allocation_create_info( + const VmaMemoryUsage usage) { + return VmaAllocationCreateInfo{ + 0u, /* VMA_ALLOCATION_CREATE_MAPPED_BIT - MoltenVK Issue #175 */ + /* VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT */ + usage, + 0u, + 0u, + 0u, + VK_NULL_HANDLE, + nullptr, + }; +} + +void release_buffer(const Resource::Buffer& buffer) { + vmaDestroyBuffer( + buffer.memory.allocator, + buffer.handle, + buffer.memory.allocation); +} + +void release_image(const Resource::Image& image) { + if (VK_NULL_HANDLE != image.view) { + VmaAllocatorInfo allocator_info{}; + vmaGetAllocatorInfo(image.memory.allocator, &allocator_info); + vkDestroyImageView(allocator_info.device, image.view, nullptr); + } + + vmaDestroyImage( + image.memory.allocator, + image.handle, + image.memory.allocation); +} + +} // namespace + +void* map(const Resource::Memory& memory) { + // Call will be ignored by implementation if the memory type this allocation + // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior + // we want. + VK_CHECK(vmaInvalidateAllocation( + memory.allocator, memory.allocation, 0u, VK_WHOLE_SIZE)); + + void* data = nullptr; + VK_CHECK(vmaMapMemory(memory.allocator, memory.allocation, &data)); + + return data; +} + +Resource::Memory::Scope::Scope( + const VmaAllocator allocator, + const VmaAllocation allocation, + const Access access) + : allocator_(allocator), + allocation_(allocation), + access_(access) { +} + +void Resource::Memory::Scope::operator()(const void* const data) const { + if (C10_UNLIKELY(!data)) { + return; + } + + vmaUnmapMemory(allocator_, allocation_); + + if (Access::Write == access_) { + // Call will be ignored by implementation if the memory type this allocation + // belongs to is not HOST_VISIBLE or is HOST_COHERENT, which is the behavior + // we want. + VK_CHECK(vmaFlushAllocation(allocator_, allocation_, 0u, VK_WHOLE_SIZE)); + } +} + +Resource::Pool::Pool( + const VkInstance instance, + const VkPhysicalDevice physical_device, + const VkDevice device) + : device_(device), + allocator_(create_allocator(instance, physical_device, device), vmaDestroyAllocator) { + buffers_.reserve(Configuration::kReserve); + images_.reserve(Configuration::kReserve); +} + +Resource::Buffer Resource::Pool::allocate(const Buffer::Descriptor& descriptor) { + const VkBufferCreateInfo buffer_create_info{ + VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, + nullptr, + 0u, + descriptor.size, + descriptor.usage.buffer, + VK_SHARING_MODE_EXCLUSIVE, + 0u, + nullptr, + }; + + const VmaAllocationCreateInfo allocation_create_info = + create_allocation_create_info(descriptor.usage.memory); + + VkBuffer buffer{}; + VmaAllocation allocation{}; + VmaAllocationInfo allocation_info{}; + + VK_CHECK(vmaCreateBuffer( + allocator_.get(), + &buffer_create_info, + &allocation_create_info, + &buffer, + &allocation, + &allocation_info)); + + buffers_.emplace_back( + Buffer{ + buffer, + Memory{ + allocator_.get(), + allocation, + allocation_info, + }, + }, + &release_buffer); + + return buffers_.back().get(); +} + +Resource::Image Resource::Pool::allocate(const Image::Descriptor& descriptor) { + const VkImageCreateInfo image_create_info{ + VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, + nullptr, + 0u, + descriptor.type, + descriptor.format, + descriptor.extent, + 1u, + 1u, + VK_SAMPLE_COUNT_1_BIT, + VK_IMAGE_TILING_OPTIMAL, + descriptor.usage.image, + VK_SHARING_MODE_EXCLUSIVE, + 0u, + nullptr, + VK_IMAGE_LAYOUT_UNDEFINED, + }; + + const VmaAllocationCreateInfo allocation_create_info = + create_allocation_create_info(descriptor.usage.memory); + + VkImage image{}; + VmaAllocation allocation{}; + VmaAllocationInfo allocation_info{}; + + VK_CHECK(vmaCreateImage( + allocator_.get(), + &image_create_info, + &allocation_create_info, + &image, + &allocation, + &allocation_info)); + + const VkImageViewCreateInfo image_view_create_info{ + VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO, + nullptr, + 0u, + image, + descriptor.view.type, + descriptor.view.format, + { + VK_COMPONENT_SWIZZLE_IDENTITY, + VK_COMPONENT_SWIZZLE_IDENTITY, + VK_COMPONENT_SWIZZLE_IDENTITY, + VK_COMPONENT_SWIZZLE_IDENTITY, + }, + { + VK_IMAGE_ASPECT_COLOR_BIT, + 0u, + 1u, + 0u, + 1u, + }, + }; + + VkImageView view{}; + VK_CHECK(vkCreateImageView( + device_, &image_view_create_info, nullptr, &view)) + + images_.emplace_back( + Image{ + image, + view, + Memory{ + allocator_.get(), + allocation, + allocation_info, + }, + }, + &release_image); + + return images_.back().get(); +} + +void Resource::Pool::purge() { + images_.clear(); + buffers_.clear(); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Resource.h b/aten/src/ATen/native/vulkan/api/Resource.h new file mode 100644 index 000000000000000..04cd9a067663078 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Resource.h @@ -0,0 +1,177 @@ +#pragma once + +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +struct C10_EXPORT Resource final { + /* + Memory + */ + + struct Memory final { + VmaAllocator allocator; + VmaAllocation allocation; + VmaAllocationInfo allocation_info; + + class Scope; + template + using Data = Handle; + + template< + typename Type, + typename Pointer = std::add_pointer_t>> + Data map() const; + + template< + typename Type, + typename Pointer = std::add_pointer_t> + Data map(); + }; + + /* + Buffer + */ + + struct Buffer final { + /* + Descriptor + */ + + struct Descriptor final { + VkDeviceSize size; + + struct { + VkBufferUsageFlags buffer; + VmaMemoryUsage memory; + } usage; + }; + + VkBuffer handle; + Memory memory; + + operator bool() const; + }; + + /* + Image + */ + + struct Image final { + /* + Descriptor + */ + + struct Descriptor final { + VkImageType type; + VkFormat format; + VkExtent3D extent; + + struct { + VkImageUsageFlags image; + VmaMemoryUsage memory; + } usage; + + struct { + VkImageViewType type; + VkFormat format; + } view; + }; + + VkImage handle; + VkImageView view; + Memory memory; + + operator bool() const; + }; + + /* + Pool + */ + + class Pool final { + public: + Pool( + VkInstance instance, + VkPhysicalDevice physical_device, + VkDevice device); + + Buffer allocate(const Buffer::Descriptor& descriptor); + Image allocate(const Image::Descriptor& descriptor); + void purge(); + + private: + struct Configuration final { + static constexpr uint32_t kReserve = 256u; + }; + + VkDevice device_; + Handle allocator_; + std::vector> buffers_; + std::vector> images_; + } pool; + + Resource( + const VkInstance instance, + const VkPhysicalDevice physical_device, + const VkDevice device) + : pool(instance, physical_device, device) { + } +}; + +// +// Impl +// + +class Resource::Memory::Scope final { + public: + enum class Access { + Read, + Write, + }; + + Scope(VmaAllocator allocator, VmaAllocation allocation, Access access); + void operator()(const void* data) const; + + private: + VmaAllocator allocator_; + VmaAllocation allocation_; + Access access_; +}; + +template +inline Resource::Memory::Data Resource::Memory::map() const { + void* map(const Memory& memory); + + return Data{ + reinterpret_cast(map(*this)), + Scope(allocator, allocation, Scope::Access::Read), + }; +} + +template +inline Resource::Memory::Data Resource::Memory::map() { + void* map(const Memory& memory); + + return Data{ + reinterpret_cast(map(*this)), + Scope(allocator, allocation, Scope::Access::Write), + }; +} + +inline Resource::Buffer::operator bool() const { + return VK_NULL_HANDLE != handle; +} + +inline Resource::Image::operator bool() const { + return VK_NULL_HANDLE != handle; +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Shader.cpp b/aten/src/ATen/native/vulkan/api/Shader.cpp new file mode 100644 index 000000000000000..bbd3e3464d78fc6 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Shader.cpp @@ -0,0 +1,153 @@ +#include + +#ifdef USE_VULKAN_SHADERC_RUNTIME +#include +#endif /* USE_VULKAN_SHADERC_RUNTIME */ + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +Shader::Layout::Factory::Factory(const VkDevice device) + : device_(device) { +} + +Shader::Layout::Factory::Handle Shader::Layout::Factory::operator()( + const Descriptor& descriptor) const { + const VkDescriptorSetLayoutCreateInfo descriptor_set_layout_create_info{ + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO, + nullptr, + 0u, + static_cast(descriptor.bindings.size()), + descriptor.bindings.data(), + }; + + VkDescriptorSetLayout descriptor_set_layout{}; + VK_CHECK(vkCreateDescriptorSetLayout( + device_, &descriptor_set_layout_create_info, nullptr, &descriptor_set_layout)); + + return Handle{ + descriptor_set_layout, + Deleter(device_), + }; +} + +Shader::Descriptor::Descriptor(const char* const glsl) + : type(Type::Source) { + shader.source = { + glsl, + 0u, + }; +} + +Shader::Descriptor::Descriptor(const uint32_t* const code, const uint32_t size) + : type(Type::Binary) { + shader.binary = { + code, + size, + }; +} + +#ifdef USE_VULKAN_SHADERC_RUNTIME + +struct Shader::Factory::Compiler final { + shaderc::Compiler context; + shaderc::CompileOptions options; + + Compiler() { + options.SetSourceLanguage(shaderc_source_language_glsl); + options.SetTargetEnvironment(shaderc_target_env_vulkan, shaderc_env_version_vulkan_1_0); + options.SetWarningsAsErrors(); + #ifdef DEBUG + options.SetGenerateDebugInfo(); + options.SetOptimizationLevel(shaderc_optimization_level_zero); + #else + options.SetOptimizationLevel(shaderc_optimization_level_performance); + #endif /* DEBUG */ + } + + std::vector compile(const char* const source) const { + const shaderc::SpvCompilationResult result = context.CompileGlslToSpv( + source, + ::strlen(source), + shaderc_compute_shader, + "vulkan_shader.comp", + options); + + const shaderc_compilation_status status = result.GetCompilationStatus(); + TORCH_INTERNAL_ASSERT( + shaderc_compilation_status_success == status, + "Shader compilation error: ", + result.GetErrorMessage()); + + return std::vector(result.cbegin(), result.cend()); + } +}; + +#else + +struct Shader::Factory::Compiler final { + std::vector compile(const char* const /* source */) const { + return std::vector{}; + } +}; + +#endif /* USE_VULKAN_SHADERC_RUNTIME */ + +Shader::Factory::Factory(const VkDevice device) + : device_(device), + compiler_(new Compiler) { +} + +// std::unique_ptr requires its template parameter to be fully defined. +// For that reason pimpl through unique_ptr requires the definition of +// the [default] constructor and move assignment operator to appear after +// impl is fully defined. + +Shader::Factory::Factory(Factory&&) = default; +Shader::Factory& Shader::Factory::Factory::operator=(Factory&&) = default; +Shader::Factory::~Factory() = default; + +typename Shader::Factory::Handle Shader::Factory::operator()( + const Descriptor& descriptor) const { + std::vector binary; + + const uint32_t* code = nullptr; + uint32_t size = 0u; + + if (Descriptor::Type::Source == descriptor.type) { + binary = compiler_->compile(descriptor.shader.source.glsl); + code = binary.data(); + size = sizeof(uint32_t) * static_cast(binary.size()); + } + else if (Descriptor::Type::Binary == descriptor.type) { + code = descriptor.shader.binary.spirv; + size = descriptor.shader.binary.size; + } + else { + TORCH_INTERNAL_ASSERT(false, "Invalid descriptor type!"); + } + + const VkShaderModuleCreateInfo shader_module_create_info{ + VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO, + nullptr, + 0u, + size, + code, + }; + + VkShaderModule shader_module{}; + VK_CHECK(vkCreateShaderModule( + device_, &shader_module_create_info, nullptr, &shader_module)); + + return Handle{ + shader_module, + Deleter(device_), + }; +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/vulkan/api/Shader.h b/aten/src/ATen/native/vulkan/api/Shader.h new file mode 100644 index 000000000000000..0fd2fa01614bbc3 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/Shader.h @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace native { +namespace vulkan { +namespace api { + +// +// This struct defines shader, and shader layout, caches intended to minimize +// redundant object reconstructions at the cost of extra memory consumption. +// +// A shader is a small, usually simple, program that typically runs on a GPU as +// part of the graphics or compute pipelines. The shader layout defines the +// interface between that program and the outside world, namely what the host +// (i.e. CPU) sees as configurable parameters of the said shader per dispatch. +// If the shader was a regular function, the shader layout would have been its +// function prototype declaring the number and type of its arguments. +// +// Furthermore, shader layouts, or as Vulkan calls them descriptor set layouts, +// define the blueprint out of which descriptor sets are instantiated. Descriptor +// sets themselves, bundle the input to and output from a shader and contain +// pointers to GPU, and GPU accessible system, memory locations where the actual +// resources reside. Shader layouts are also used in creation of Vulkan pipeline +// layouts, while multiple shaders are bundled together to form a portion of the +// the monolithic state objects that are Vulkan pipelines. +// +// This struct defines the facilities required to create, compile, reuse, +// and destruct the aforementioned Vulkan objects. +// + +struct C10_EXPORT Shader final { + // + // Layout + // + + struct Layout final { + /* + Descriptor + */ + + struct Descriptor final { + c10::SmallVector bindings; + }; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(VkDevice device); + + typedef Layout::Descriptor Descriptor; + typedef VK_DELETER(DescriptorSetLayout) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Layout(const VkDevice device) + : cache(Factory(device)) { + } + } layout; + + // + // Work Group + // + + struct WorkGroup final { + uint32_t x; + uint32_t y; + uint32_t z; + }; + + /* + Descriptor + */ + + struct Descriptor final { + enum class Type { + Source, + Binary, + } type; + + union { + struct { + const char* glsl; // Null-terminated + uint32_t unused; // Padding + } source; + + struct { + const uint32_t* spirv; + uint32_t size; // Bytes + } binary; + } shader; + + Descriptor(const char* glsl); + Descriptor(const uint32_t* spirv, uint32_t bytes); + }; + + /* + Factory + */ + + class Factory final { + public: + explicit Factory(VkDevice device); + Factory(const Factory&) = delete; + Factory& operator=(const Factory&) = delete; + Factory(Factory&&); + Factory& operator=(Factory&&); + ~Factory(); + + typedef Shader::Descriptor Descriptor; + typedef VK_DELETER(ShaderModule) Deleter; + typedef Handle Handle; + + struct Hasher { + size_t operator()(const Descriptor& descriptor) const; + }; + + Handle operator()(const Descriptor& descriptor) const; + + private: + VkDevice device_; + struct Compiler; + std::unique_ptr compiler_; + }; + + /* + Cache + */ + + typedef api::Cache Cache; + Cache cache; + + explicit Shader(const VkDevice device) + : layout(device), + cache(Factory(device)) { + } +}; + +// +// Impl +// + +inline bool operator==( + const Shader::Layout::Descriptor& _1, + const Shader::Layout::Descriptor& _2) { + return _1.bindings == _2.bindings; +} + +inline size_t Shader::Layout::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + size_t hash = 0u; + + for (const VkDescriptorSetLayoutBinding& binding : descriptor.bindings) { + hash = c10::hash_combine( + hash, + c10::get_hash( + binding.binding, + binding.descriptorType, + binding.descriptorCount, + binding.stageFlags, + binding.pImmutableSamplers)); + } + + return hash; +} + +inline bool operator==( + const Shader::WorkGroup& work_group_1, + const Shader::WorkGroup& work_group_2) { + return (work_group_1.x == work_group_2.x) && + (work_group_1.y == work_group_2.y) && + (work_group_1.z == work_group_2.z); +} + +inline bool operator==( + const Shader::Descriptor& _1, + const Shader::Descriptor& _2) { + static_assert( + sizeof(Shader::Descriptor::shader.source) == sizeof(Shader::Descriptor::shader.binary), + "This implementation requires sizeof(Source) to be equal to sizeof(Binary)."); + + return (_1.type == _2.type) && + (_1.shader.binary.spirv == _2.shader.binary.spirv) && + (_1.shader.binary.size == _2.shader.binary.size); +} + +inline size_t Shader::Factory::Hasher::operator()( + const Descriptor& descriptor) const { + static_assert( + sizeof(Descriptor::shader.source) == sizeof(Descriptor::shader.binary), + "This implementation requires sizeof(Source) to be equal to sizeof(Binary)."); + + return c10::get_hash( + descriptor.type, + descriptor.shader.binary.spirv, + descriptor.shader.binary.size); +} + +} // namespace api +} // namespace vulkan +} // namespace native +} // namespace at + +inline bool operator==( + const VkDescriptorSetLayoutBinding& _1, + const VkDescriptorSetLayoutBinding& _2) { + return (_1.binding == _2.binding) && + (_1.descriptorType == _2.descriptorType) && + (_1.descriptorCount == _2.descriptorCount) && + (_1.stageFlags == _2.stageFlags) && + (_1.pImmutableSamplers == _2.pImmutableSamplers); +} diff --git a/aten/src/ATen/native/vulkan/api/api.h b/aten/src/ATen/native/vulkan/api/api.h new file mode 100644 index 000000000000000..394f55d7d525ee5 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/api.h @@ -0,0 +1,10 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include diff --git a/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h new file mode 100644 index 000000000000000..fdeadf9cdbfa754 --- /dev/null +++ b/aten/src/ATen/native/vulkan/api/vk_mem_alloc.h @@ -0,0 +1,19074 @@ +// +// Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +// + +#ifndef AMD_VULKAN_MEMORY_ALLOCATOR_H +#define AMD_VULKAN_MEMORY_ALLOCATOR_H + +/** \mainpage Vulkan Memory Allocator + +Version 3.0.0-development (2020-06-24) + +Copyright (c) 2017-2020 Advanced Micro Devices, Inc. All rights reserved. \n +License: MIT + +Documentation of all members: vk_mem_alloc.h + +\section main_table_of_contents Table of contents + +- User guide + - \subpage quick_start + - [Project setup](@ref quick_start_project_setup) + - [Initialization](@ref quick_start_initialization) + - [Resource allocation](@ref quick_start_resource_allocation) + - \subpage choosing_memory_type + - [Usage](@ref choosing_memory_type_usage) + - [Required and preferred flags](@ref choosing_memory_type_required_preferred_flags) + - [Explicit memory types](@ref choosing_memory_type_explicit_memory_types) + - [Custom memory pools](@ref choosing_memory_type_custom_memory_pools) + - [Dedicated allocations](@ref choosing_memory_type_dedicated_allocations) + - \subpage memory_mapping + - [Mapping functions](@ref memory_mapping_mapping_functions) + - [Persistently mapped memory](@ref memory_mapping_persistently_mapped_memory) + - [Cache flush and invalidate](@ref memory_mapping_cache_control) + - [Finding out if memory is mappable](@ref memory_mapping_finding_if_memory_mappable) + - \subpage staying_within_budget + - [Querying for budget](@ref staying_within_budget_querying_for_budget) + - [Controlling memory usage](@ref staying_within_budget_controlling_memory_usage) + - \subpage custom_memory_pools + - [Choosing memory type index](@ref custom_memory_pools_MemTypeIndex) + - [Linear allocation algorithm](@ref linear_algorithm) + - [Free-at-once](@ref linear_algorithm_free_at_once) + - [Stack](@ref linear_algorithm_stack) + - [Double stack](@ref linear_algorithm_double_stack) + - [Ring buffer](@ref linear_algorithm_ring_buffer) + - [Buddy allocation algorithm](@ref buddy_algorithm) + - \subpage defragmentation + - [Defragmenting CPU memory](@ref defragmentation_cpu) + - [Defragmenting GPU memory](@ref defragmentation_gpu) + - [Additional notes](@ref defragmentation_additional_notes) + - [Writing custom allocation algorithm](@ref defragmentation_custom_algorithm) + - \subpage lost_allocations + - \subpage statistics + - [Numeric statistics](@ref statistics_numeric_statistics) + - [JSON dump](@ref statistics_json_dump) + - \subpage allocation_annotation + - [Allocation user data](@ref allocation_user_data) + - [Allocation names](@ref allocation_names) + - \subpage debugging_memory_usage + - [Memory initialization](@ref debugging_memory_usage_initialization) + - [Margins](@ref debugging_memory_usage_margins) + - [Corruption detection](@ref debugging_memory_usage_corruption_detection) + - \subpage record_and_replay +- \subpage usage_patterns + - [Common mistakes](@ref usage_patterns_common_mistakes) + - [Simple patterns](@ref usage_patterns_simple) + - [Advanced patterns](@ref usage_patterns_advanced) +- \subpage configuration + - [Pointers to Vulkan functions](@ref config_Vulkan_functions) + - [Custom host memory allocator](@ref custom_memory_allocator) + - [Device memory allocation callbacks](@ref allocation_callbacks) + - [Device heap memory limit](@ref heap_memory_limit) + - \subpage vk_khr_dedicated_allocation + - \subpage enabling_buffer_device_address + - \subpage vk_amd_device_coherent_memory +- \subpage general_considerations + - [Thread safety](@ref general_considerations_thread_safety) + - [Validation layer warnings](@ref general_considerations_validation_layer_warnings) + - [Allocation algorithm](@ref general_considerations_allocation_algorithm) + - [Features not supported](@ref general_considerations_features_not_supported) + +\section main_see_also See also + +- [Product page on GPUOpen](https://gpuopen.com/gaming-product/vulkan-memory-allocator/) +- [Source repository on GitHub](https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator) + + + + +\page quick_start Quick start + +\section quick_start_project_setup Project setup + +Vulkan Memory Allocator comes in form of a "stb-style" single header file. +You don't need to build it as a separate library project. +You can add this file directly to your project and submit it to code repository next to your other source files. + +"Single header" doesn't mean that everything is contained in C/C++ declarations, +like it tends to be in case of inline functions or C++ templates. +It means that implementation is bundled with interface in a single file and needs to be extracted using preprocessor macro. +If you don't do it properly, you will get linker errors. + +To do it properly: + +-# Include "vk_mem_alloc.h" file in each CPP file where you want to use the library. + This includes declarations of all members of the library. +-# In exacly one CPP file define following macro before this include. + It enables also internal definitions. + +\code +#define VMA_IMPLEMENTATION +#include vk_mem_alloc.h +\endcode + +It may be a good idea to create dedicated CPP file just for this purpose. + +Note on language: This library is written in C++, but has C-compatible interface. +Thus you can include and use vk_mem_alloc.h in C or C++ code, but full +implementation with `VMA_IMPLEMENTATION` macro must be compiled as C++, NOT as C. + +Please note that this library includes header ``, which in turn +includes `` on Windows. If you need some specific macros defined +before including these headers (like `WIN32_LEAN_AND_MEAN` or +`WINVER` for Windows, `VK_USE_PLATFORM_WIN32_KHR` for Vulkan), you must define +them before every `#include` of this library. + + +\section quick_start_initialization Initialization + +At program startup: + +-# Initialize Vulkan to have `VkPhysicalDevice`, `VkDevice` and `VkInstance` object. +-# Fill VmaAllocatorCreateInfo structure and create #VmaAllocator object by + calling vmaCreateAllocator(). + +\code +VmaAllocatorCreateInfo allocatorInfo = {}; +allocatorInfo.physicalDevice = physicalDevice; +allocatorInfo.device = device; +allocatorInfo.instance = instance; + +VmaAllocator allocator; +vmaCreateAllocator(&allocatorInfo, &allocator); +\endcode + +\section quick_start_resource_allocation Resource allocation + +When you want to create a buffer or image: + +-# Fill `VkBufferCreateInfo` / `VkImageCreateInfo` structure. +-# Fill VmaAllocationCreateInfo structure. +-# Call vmaCreateBuffer() / vmaCreateImage() to get `VkBuffer`/`VkImage` with memory + already allocated and bound to it. + +\code +VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufferInfo.size = 65536; +bufferInfo.usage = VK_BUFFER_USAGE_VERTEX_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + +Don't forget to destroy your objects when no longer needed: + +\code +vmaDestroyBuffer(allocator, buffer, allocation); +vmaDestroyAllocator(allocator); +\endcode + + +\page choosing_memory_type Choosing memory type + +Physical devices in Vulkan support various combinations of memory heaps and +types. Help with choosing correct and optimal memory type for your specific +resource is one of the key features of this library. You can use it by filling +appropriate members of VmaAllocationCreateInfo structure, as described below. +You can also combine multiple methods. + +-# If you just want to find memory type index that meets your requirements, you + can use function: vmaFindMemoryTypeIndex(), vmaFindMemoryTypeIndexForBufferInfo(), + vmaFindMemoryTypeIndexForImageInfo(). +-# If you want to allocate a region of device memory without association with any + specific image or buffer, you can use function vmaAllocateMemory(). Usage of + this function is not recommended and usually not needed. + vmaAllocateMemoryPages() function is also provided for creating multiple allocations at once, + which may be useful for sparse binding. +-# If you already have a buffer or an image created, you want to allocate memory + for it and then you will bind it yourself, you can use function + vmaAllocateMemoryForBuffer(), vmaAllocateMemoryForImage(). + For binding you should use functions: vmaBindBufferMemory(), vmaBindImageMemory() + or their extended versions: vmaBindBufferMemory2(), vmaBindImageMemory2(). +-# If you want to create a buffer or an image, allocate memory for it and bind + them together, all in one call, you can use function vmaCreateBuffer(), + vmaCreateImage(). This is the easiest and recommended way to use this library. + +When using 3. or 4., the library internally queries Vulkan for memory types +supported for that buffer or image (function `vkGetBufferMemoryRequirements()`) +and uses only one of these types. + +If no memory type can be found that meets all the requirements, these functions +return `VK_ERROR_FEATURE_NOT_PRESENT`. + +You can leave VmaAllocationCreateInfo structure completely filled with zeros. +It means no requirements are specified for memory type. +It is valid, although not very useful. + +\section choosing_memory_type_usage Usage + +The easiest way to specify memory requirements is to fill member +VmaAllocationCreateInfo::usage using one of the values of enum #VmaMemoryUsage. +It defines high level, common usage types. +For more details, see description of this enum. + +For example, if you want to create a uniform buffer that will be filled using +transfer only once or infrequently and used for rendering every frame, you can +do it using following code: + +\code +VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufferInfo.size = 65536; +bufferInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + +\section choosing_memory_type_required_preferred_flags Required and preferred flags + +You can specify more detailed requirements by filling members +VmaAllocationCreateInfo::requiredFlags and VmaAllocationCreateInfo::preferredFlags +with a combination of bits from enum `VkMemoryPropertyFlags`. For example, +if you want to create a buffer that will be persistently mapped on host (so it +must be `HOST_VISIBLE`) and preferably will also be `HOST_COHERENT` and `HOST_CACHED`, +use following code: + +\code +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; +allocInfo.preferredFlags = VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT; +allocInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + +A memory type is chosen that has all the required flags and as many preferred +flags set as possible. + +If you use VmaAllocationCreateInfo::usage, it is just internally converted to +a set of required and preferred flags. + +\section choosing_memory_type_explicit_memory_types Explicit memory types + +If you inspected memory types available on the physical device and you have +a preference for memory types that you want to use, you can fill member +VmaAllocationCreateInfo::memoryTypeBits. It is a bit mask, where each bit set +means that a memory type with that index is allowed to be used for the +allocation. Special value 0, just like `UINT32_MAX`, means there are no +restrictions to memory type index. + +Please note that this member is NOT just a memory type index. +Still you can use it to choose just one, specific memory type. +For example, if you already determined that your buffer should be created in +memory type 2, use following code: + +\code +uint32_t memoryTypeIndex = 2; + +VmaAllocationCreateInfo allocInfo = {}; +allocInfo.memoryTypeBits = 1u << memoryTypeIndex; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocInfo, &buffer, &allocation, nullptr); +\endcode + +\section choosing_memory_type_custom_memory_pools Custom memory pools + +If you allocate from custom memory pool, all the ways of specifying memory +requirements described above are not applicable and the aforementioned members +of VmaAllocationCreateInfo structure are ignored. Memory type is selected +explicitly when creating the pool and then used to make all the allocations from +that pool. For further details, see \ref custom_memory_pools. + +\section choosing_memory_type_dedicated_allocations Dedicated allocations + +Memory for allocations is reserved out of larger block of `VkDeviceMemory` +allocated from Vulkan internally. That's the main feature of this whole library. +You can still request a separate memory block to be created for an allocation, +just like you would do in a trivial solution without using any allocator. +In that case, a buffer or image is always bound to that memory at offset 0. +This is called a "dedicated allocation". +You can explicitly request it by using flag #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +The library can also internally decide to use dedicated allocation in some cases, e.g.: + +- When the size of the allocation is large. +- When [VK_KHR_dedicated_allocation](@ref vk_khr_dedicated_allocation) extension is enabled + and it reports that dedicated allocation is required or recommended for the resource. +- When allocation of next big memory block fails due to not enough device memory, + but allocation with the exact requested size succeeds. + + +\page memory_mapping Memory mapping + +To "map memory" in Vulkan means to obtain a CPU pointer to `VkDeviceMemory`, +to be able to read from it or write to it in CPU code. +Mapping is possible only of memory allocated from a memory type that has +`VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` flag. +Functions `vkMapMemory()`, `vkUnmapMemory()` are designed for this purpose. +You can use them directly with memory allocated by this library, +but it is not recommended because of following issue: +Mapping the same `VkDeviceMemory` block multiple times is illegal - only one mapping at a time is allowed. +This includes mapping disjoint regions. Mapping is not reference-counted internally by Vulkan. +Because of this, Vulkan Memory Allocator provides following facilities: + +\section memory_mapping_mapping_functions Mapping functions + +The library provides following functions for mapping of a specific #VmaAllocation: vmaMapMemory(), vmaUnmapMemory(). +They are safer and more convenient to use than standard Vulkan functions. +You can map an allocation multiple times simultaneously - mapping is reference-counted internally. +You can also map different allocations simultaneously regardless of whether they use the same `VkDeviceMemory` block. +The way it's implemented is that the library always maps entire memory block, not just region of the allocation. +For further details, see description of vmaMapMemory() function. +Example: + +\code +// Having these objects initialized: + +struct ConstantBuffer +{ + ... +}; +ConstantBuffer constantBufferData; + +VmaAllocator allocator; +VkBuffer constantBuffer; +VmaAllocation constantBufferAllocation; + +// You can map and fill your buffer using following code: + +void* mappedData; +vmaMapMemory(allocator, constantBufferAllocation, &mappedData); +memcpy(mappedData, &constantBufferData, sizeof(constantBufferData)); +vmaUnmapMemory(allocator, constantBufferAllocation); +\endcode + +When mapping, you may see a warning from Vulkan validation layer similar to this one: + +Mapping an image with layout VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL can result in undefined behavior if this memory is used by the device. Only GENERAL or PREINITIALIZED should be used. + +It happens because the library maps entire `VkDeviceMemory` block, where different +types of images and buffers may end up together, especially on GPUs with unified memory like Intel. +You can safely ignore it if you are sure you access only memory of the intended +object that you wanted to map. + + +\section memory_mapping_persistently_mapped_memory Persistently mapped memory + +Kepping your memory persistently mapped is generally OK in Vulkan. +You don't need to unmap it before using its data on the GPU. +The library provides a special feature designed for that: +Allocations made with #VMA_ALLOCATION_CREATE_MAPPED_BIT flag set in +VmaAllocationCreateInfo::flags stay mapped all the time, +so you can just access CPU pointer to it any time +without a need to call any "map" or "unmap" function. +Example: + +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = sizeof(ConstantBuffer); +bufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_CPU_ONLY; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT; + +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + +// Buffer is already mapped. You can access its memory. +memcpy(allocInfo.pMappedData, &constantBufferData, sizeof(constantBufferData)); +\endcode + +There are some exceptions though, when you should consider mapping memory only for a short period of time: + +- When operating system is Windows 7 or 8.x (Windows 10 is not affected because it uses WDDM2), + device is discrete AMD GPU, + and memory type is the special 256 MiB pool of `DEVICE_LOCAL + HOST_VISIBLE` memory + (selected when you use #VMA_MEMORY_USAGE_CPU_TO_GPU), + then whenever a memory block allocated from this memory type stays mapped + for the time of any call to `vkQueueSubmit()` or `vkQueuePresentKHR()`, this + block is migrated by WDDM to system RAM, which degrades performance. It doesn't + matter if that particular memory block is actually used by the command buffer + being submitted. +- On Mac/MoltenVK there is a known bug - [Issue #175](https://github.com/KhronosGroup/MoltenVK/issues/175) + which requires unmapping before GPU can see updated texture. +- Keeping many large memory blocks mapped may impact performance or stability of some debugging tools. + +\section memory_mapping_cache_control Cache flush and invalidate + +Memory in Vulkan doesn't need to be unmapped before using it on GPU, +but unless a memory types has `VK_MEMORY_PROPERTY_HOST_COHERENT_BIT` flag set, +you need to manually **invalidate** cache before reading of mapped pointer +and **flush** cache after writing to mapped pointer. +Map/unmap operations don't do that automatically. +Vulkan provides following functions for this purpose `vkFlushMappedMemoryRanges()`, +`vkInvalidateMappedMemoryRanges()`, but this library provides more convenient +functions that refer to given allocation object: vmaFlushAllocation(), +vmaInvalidateAllocation(), +or multiple objects at once: vmaFlushAllocations(), vmaInvalidateAllocations(). + +Regions of memory specified for flush/invalidate must be aligned to +`VkPhysicalDeviceLimits::nonCoherentAtomSize`. This is automatically ensured by the library. +In any memory type that is `HOST_VISIBLE` but not `HOST_COHERENT`, all allocations +within blocks are aligned to this value, so their offsets are always multiply of +`nonCoherentAtomSize` and two different allocations never share same "line" of this size. + +Please note that memory allocated with #VMA_MEMORY_USAGE_CPU_ONLY is guaranteed to be `HOST_COHERENT`. + +Also, Windows drivers from all 3 **PC** GPU vendors (AMD, Intel, NVIDIA) +currently provide `HOST_COHERENT` flag on all memory types that are +`HOST_VISIBLE`, so on this platform you may not need to bother. + +\section memory_mapping_finding_if_memory_mappable Finding out if memory is mappable + +It may happen that your allocation ends up in memory that is `HOST_VISIBLE` (available for mapping) +despite it wasn't explicitly requested. +For example, application may work on integrated graphics with unified memory (like Intel) or +allocation from video memory might have failed, so the library chose system memory as fallback. + +You can detect this case and map such allocation to access its memory on CPU directly, +instead of launching a transfer operation. +In order to do that: inspect `allocInfo.memoryType`, call vmaGetMemoryTypeProperties(), +and look for `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` flag in properties of that memory type. + +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = sizeof(ConstantBuffer); +bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; +allocCreateInfo.preferredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + +VkMemoryPropertyFlags memFlags; +vmaGetMemoryTypeProperties(allocator, allocInfo.memoryType, &memFlags); +if((memFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) +{ + // Allocation ended up in mappable memory. You can map it and access it directly. + void* mappedData; + vmaMapMemory(allocator, alloc, &mappedData); + memcpy(mappedData, &constantBufferData, sizeof(constantBufferData)); + vmaUnmapMemory(allocator, alloc); +} +else +{ + // Allocation ended up in non-mappable memory. + // You need to create CPU-side buffer in VMA_MEMORY_USAGE_CPU_ONLY and make a transfer. +} +\endcode + +You can even use #VMA_ALLOCATION_CREATE_MAPPED_BIT flag while creating allocations +that are not necessarily `HOST_VISIBLE` (e.g. using #VMA_MEMORY_USAGE_GPU_ONLY). +If the allocation ends up in memory type that is `HOST_VISIBLE`, it will be persistently mapped and you can use it directly. +If not, the flag is just ignored. +Example: + +\code +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = sizeof(ConstantBuffer); +bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_MAPPED_BIT; + +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); + +if(allocInfo.pUserData != nullptr) +{ + // Allocation ended up in mappable memory. + // It's persistently mapped. You can access it directly. + memcpy(allocInfo.pMappedData, &constantBufferData, sizeof(constantBufferData)); +} +else +{ + // Allocation ended up in non-mappable memory. + // You need to create CPU-side buffer in VMA_MEMORY_USAGE_CPU_ONLY and make a transfer. +} +\endcode + + +\page staying_within_budget Staying within budget + +When developing a graphics-intensive game or program, it is important to avoid allocating +more GPU memory than it's physically available. When the memory is over-committed, +various bad things can happen, depending on the specific GPU, graphics driver, and +operating system: + +- It may just work without any problems. +- The application may slow down because some memory blocks are moved to system RAM + and the GPU has to access them through PCI Express bus. +- A new allocation may take very long time to complete, even few seconds, and possibly + freeze entire system. +- The new allocation may fail with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +- It may even result in GPU crash (TDR), observed as `VK_ERROR_DEVICE_LOST` + returned somewhere later. + +\section staying_within_budget_querying_for_budget Querying for budget + +To query for current memory usage and available budget, use function vmaGetBudget(). +Returned structure #VmaBudget contains quantities expressed in bytes, per Vulkan memory heap. + +Please note that this function returns different information and works faster than +vmaCalculateStats(). vmaGetBudget() can be called every frame or even before every +allocation, while vmaCalculateStats() is intended to be used rarely, +only to obtain statistical information, e.g. for debugging purposes. + +It is recommended to use VK_EXT_memory_budget device extension to obtain information +about the budget from Vulkan device. VMA is able to use this extension automatically. +When not enabled, the allocator behaves same way, but then it estimates current usage +and available budget based on its internal information and Vulkan memory heap sizes, +which may be less precise. In order to use this extension: + +1. Make sure extensions VK_EXT_memory_budget and VK_KHR_get_physical_device_properties2 + required by it are available and enable them. Please note that the first is a device + extension and the second is instance extension! +2. Use flag #VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT when creating #VmaAllocator object. +3. Make sure to call vmaSetCurrentFrameIndex() every frame. Budget is queried from + Vulkan inside of it to avoid overhead of querying it with every allocation. + +\section staying_within_budget_controlling_memory_usage Controlling memory usage + +There are many ways in which you can try to stay within the budget. + +First, when making new allocation requires allocating a new memory block, the library +tries not to exceed the budget automatically. If a block with default recommended size +(e.g. 256 MB) would go over budget, a smaller block is allocated, possibly even +dedicated memory for just this resource. + +If the size of the requested resource plus current memory usage is more than the +budget, by default the library still tries to create it, leaving it to the Vulkan +implementation whether the allocation succeeds or fails. You can change this behavior +by using #VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT flag. With it, the allocation is +not made if it would exceed the budget or if the budget is already exceeded. +Some other allocations become lost instead to make room for it, if the mechanism of +[lost allocations](@ref lost_allocations) is used. +If that is not possible, the allocation fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +Example usage pattern may be to pass the #VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT flag +when creating resources that are not essential for the application (e.g. the texture +of a specific object) and not to pass it when creating critically important resources +(e.g. render targets). + +Finally, you can also use #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT flag to make sure +a new allocation is created only when it fits inside one of the existing memory blocks. +If it would require to allocate a new block, if fails instead with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. +This also ensures that the function call is very fast because it never goes to Vulkan +to obtain a new block. + +Please note that creating \ref custom_memory_pools with VmaPoolCreateInfo::minBlockCount +set to more than 0 will try to allocate memory blocks without checking whether they +fit within budget. + + +\page custom_memory_pools Custom memory pools + +A memory pool contains a number of `VkDeviceMemory` blocks. +The library automatically creates and manages default pool for each memory type available on the device. +Default memory pool automatically grows in size. +Size of allocated blocks is also variable and managed automatically. + +You can create custom pool and allocate memory out of it. +It can be useful if you want to: + +- Keep certain kind of allocations separate from others. +- Enforce particular, fixed size of Vulkan memory blocks. +- Limit maximum amount of Vulkan memory allocated for that pool. +- Reserve minimum or fixed amount of Vulkan memory always preallocated for that pool. + +To use custom memory pools: + +-# Fill VmaPoolCreateInfo structure. +-# Call vmaCreatePool() to obtain #VmaPool handle. +-# When making an allocation, set VmaAllocationCreateInfo::pool to this handle. + You don't need to specify any other parameters of this structure, like `usage`. + +Example: + +\code +// Create a pool that can have at most 2 blocks, 128 MiB each. +VmaPoolCreateInfo poolCreateInfo = {}; +poolCreateInfo.memoryTypeIndex = ... +poolCreateInfo.blockSize = 128ull * 1024 * 1024; +poolCreateInfo.maxBlockCount = 2; + +VmaPool pool; +vmaCreatePool(allocator, &poolCreateInfo, &pool); + +// Allocate a buffer out of it. +VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +bufCreateInfo.size = 1024; +bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.pool = pool; + +VkBuffer buf; +VmaAllocation alloc; +VmaAllocationInfo allocInfo; +vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &buf, &alloc, &allocInfo); +\endcode + +You have to free all allocations made from this pool before destroying it. + +\code +vmaDestroyBuffer(allocator, buf, alloc); +vmaDestroyPool(allocator, pool); +\endcode + +\section custom_memory_pools_MemTypeIndex Choosing memory type index + +When creating a pool, you must explicitly specify memory type index. +To find the one suitable for your buffers or images, you can use helper functions +vmaFindMemoryTypeIndexForBufferInfo(), vmaFindMemoryTypeIndexForImageInfo(). +You need to provide structures with example parameters of buffers or images +that you are going to create in that pool. + +\code +VkBufferCreateInfo exampleBufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +exampleBufCreateInfo.size = 1024; // Whatever. +exampleBufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; // Change if needed. + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; // Change if needed. + +uint32_t memTypeIndex; +vmaFindMemoryTypeIndexForBufferInfo(allocator, &exampleBufCreateInfo, &allocCreateInfo, &memTypeIndex); + +VmaPoolCreateInfo poolCreateInfo = {}; +poolCreateInfo.memoryTypeIndex = memTypeIndex; +// ... +\endcode + +When creating buffers/images allocated in that pool, provide following parameters: + +- `VkBufferCreateInfo`: Prefer to pass same parameters as above. + Otherwise you risk creating resources in a memory type that is not suitable for them, which may result in undefined behavior. + Using different `VK_BUFFER_USAGE_` flags may work, but you shouldn't create images in a pool intended for buffers + or the other way around. +- VmaAllocationCreateInfo: You don't need to pass same parameters. Fill only `pool` member. + Other members are ignored anyway. + +\section linear_algorithm Linear allocation algorithm + +Each Vulkan memory block managed by this library has accompanying metadata that +keeps track of used and unused regions. By default, the metadata structure and +algorithm tries to find best place for new allocations among free regions to +optimize memory usage. This way you can allocate and free objects in any order. + +![Default allocation algorithm](../gfx/Linear_allocator_1_algo_default.png) + +Sometimes there is a need to use simpler, linear allocation algorithm. You can +create custom pool that uses such algorithm by adding flag +#VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT to VmaPoolCreateInfo::flags while creating +#VmaPool object. Then an alternative metadata management is used. It always +creates new allocations after last one and doesn't reuse free regions after +allocations freed in the middle. It results in better allocation performance and +less memory consumed by metadata. + +![Linear allocation algorithm](../gfx/Linear_allocator_2_algo_linear.png) + +With this one flag, you can create a custom pool that can be used in many ways: +free-at-once, stack, double stack, and ring buffer. See below for details. + +\subsection linear_algorithm_free_at_once Free-at-once + +In a pool that uses linear algorithm, you still need to free all the allocations +individually, e.g. by using vmaFreeMemory() or vmaDestroyBuffer(). You can free +them in any order. New allocations are always made after last one - free space +in the middle is not reused. However, when you release all the allocation and +the pool becomes empty, allocation starts from the beginning again. This way you +can use linear algorithm to speed up creation of allocations that you are going +to release all at once. + +![Free-at-once](../gfx/Linear_allocator_3_free_at_once.png) + +This mode is also available for pools created with VmaPoolCreateInfo::maxBlockCount +value that allows multiple memory blocks. + +\subsection linear_algorithm_stack Stack + +When you free an allocation that was created last, its space can be reused. +Thanks to this, if you always release allocations in the order opposite to their +creation (LIFO - Last In First Out), you can achieve behavior of a stack. + +![Stack](../gfx/Linear_allocator_4_stack.png) + +This mode is also available for pools created with VmaPoolCreateInfo::maxBlockCount +value that allows multiple memory blocks. + +\subsection linear_algorithm_double_stack Double stack + +The space reserved by a custom pool with linear algorithm may be used by two +stacks: + +- First, default one, growing up from offset 0. +- Second, "upper" one, growing down from the end towards lower offsets. + +To make allocation from upper stack, add flag #VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT +to VmaAllocationCreateInfo::flags. + +![Double stack](../gfx/Linear_allocator_7_double_stack.png) + +Double stack is available only in pools with one memory block - +VmaPoolCreateInfo::maxBlockCount must be 1. Otherwise behavior is undefined. + +When the two stacks' ends meet so there is not enough space between them for a +new allocation, such allocation fails with usual +`VK_ERROR_OUT_OF_DEVICE_MEMORY` error. + +\subsection linear_algorithm_ring_buffer Ring buffer + +When you free some allocations from the beginning and there is not enough free space +for a new one at the end of a pool, allocator's "cursor" wraps around to the +beginning and starts allocation there. Thanks to this, if you always release +allocations in the same order as you created them (FIFO - First In First Out), +you can achieve behavior of a ring buffer / queue. + +![Ring buffer](../gfx/Linear_allocator_5_ring_buffer.png) + +Pools with linear algorithm support [lost allocations](@ref lost_allocations) when used as ring buffer. +If there is not enough free space for a new allocation, but existing allocations +from the front of the queue can become lost, they become lost and the allocation +succeeds. + +![Ring buffer with lost allocations](../gfx/Linear_allocator_6_ring_buffer_lost.png) + +Ring buffer is available only in pools with one memory block - +VmaPoolCreateInfo::maxBlockCount must be 1. Otherwise behavior is undefined. + +\section buddy_algorithm Buddy allocation algorithm + +There is another allocation algorithm that can be used with custom pools, called +"buddy". Its internal data structure is based on a tree of blocks, each having +size that is a power of two and a half of its parent's size. When you want to +allocate memory of certain size, a free node in the tree is located. If it's too +large, it is recursively split into two halves (called "buddies"). However, if +requested allocation size is not a power of two, the size of a tree node is +aligned up to the nearest power of two and the remaining space is wasted. When +two buddy nodes become free, they are merged back into one larger node. + +![Buddy allocator](../gfx/Buddy_allocator.png) + +The advantage of buddy allocation algorithm over default algorithm is faster +allocation and deallocation, as well as smaller external fragmentation. The +disadvantage is more wasted space (internal fragmentation). + +For more information, please read ["Buddy memory allocation" on Wikipedia](https://en.wikipedia.org/wiki/Buddy_memory_allocation) +or other sources that describe this concept in general. + +To use buddy allocation algorithm with a custom pool, add flag +#VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT to VmaPoolCreateInfo::flags while creating +#VmaPool object. + +Several limitations apply to pools that use buddy algorithm: + +- It is recommended to use VmaPoolCreateInfo::blockSize that is a power of two. + Otherwise, only largest power of two smaller than the size is used for + allocations. The remaining space always stays unused. +- [Margins](@ref debugging_memory_usage_margins) and + [corruption detection](@ref debugging_memory_usage_corruption_detection) + don't work in such pools. +- [Lost allocations](@ref lost_allocations) don't work in such pools. You can + use them, but they never become lost. Support may be added in the future. +- [Defragmentation](@ref defragmentation) doesn't work with allocations made from + such pool. + +\page defragmentation Defragmentation + +Interleaved allocations and deallocations of many objects of varying size can +cause fragmentation over time, which can lead to a situation where the library is unable +to find a continuous range of free memory for a new allocation despite there is +enough free space, just scattered across many small free ranges between existing +allocations. + +To mitigate this problem, you can use defragmentation feature: +structure #VmaDefragmentationInfo2, function vmaDefragmentationBegin(), vmaDefragmentationEnd(). +Given set of allocations, +this function can move them to compact used memory, ensure more continuous free +space and possibly also free some `VkDeviceMemory` blocks. + +What the defragmentation does is: + +- Updates #VmaAllocation objects to point to new `VkDeviceMemory` and offset. + After allocation has been moved, its VmaAllocationInfo::deviceMemory and/or + VmaAllocationInfo::offset changes. You must query them again using + vmaGetAllocationInfo() if you need them. +- Moves actual data in memory. + +What it doesn't do, so you need to do it yourself: + +- Recreate buffers and images that were bound to allocations that were defragmented and + bind them with their new places in memory. + You must use `vkDestroyBuffer()`, `vkDestroyImage()`, + `vkCreateBuffer()`, `vkCreateImage()`, vmaBindBufferMemory(), vmaBindImageMemory() + for that purpose and NOT vmaDestroyBuffer(), + vmaDestroyImage(), vmaCreateBuffer(), vmaCreateImage(), because you don't need to + destroy or create allocation objects! +- Recreate views and update descriptors that point to these buffers and images. + +\section defragmentation_cpu Defragmenting CPU memory + +Following example demonstrates how you can run defragmentation on CPU. +Only allocations created in memory types that are `HOST_VISIBLE` can be defragmented. +Others are ignored. + +The way it works is: + +- It temporarily maps entire memory blocks when necessary. +- It moves data using `memmove()` function. + +\code +// Given following variables already initialized: +VkDevice device; +VmaAllocator allocator; +std::vector buffers; +std::vector allocations; + + +const uint32_t allocCount = (uint32_t)allocations.size(); +std::vector allocationsChanged(allocCount); + +VmaDefragmentationInfo2 defragInfo = {}; +defragInfo.allocationCount = allocCount; +defragInfo.pAllocations = allocations.data(); +defragInfo.pAllocationsChanged = allocationsChanged.data(); +defragInfo.maxCpuBytesToMove = VK_WHOLE_SIZE; // No limit. +defragInfo.maxCpuAllocationsToMove = UINT32_MAX; // No limit. + +VmaDefragmentationContext defragCtx; +vmaDefragmentationBegin(allocator, &defragInfo, nullptr, &defragCtx); +vmaDefragmentationEnd(allocator, defragCtx); + +for(uint32_t i = 0; i < allocCount; ++i) +{ + if(allocationsChanged[i]) + { + // Destroy buffer that is immutably bound to memory region which is no longer valid. + vkDestroyBuffer(device, buffers[i], nullptr); + + // Create new buffer with same parameters. + VkBufferCreateInfo bufferInfo = ...; + vkCreateBuffer(device, &bufferInfo, nullptr, &buffers[i]); + + // You can make dummy call to vkGetBufferMemoryRequirements here to silence validation layer warning. + + // Bind new buffer to new memory region. Data contained in it is already moved. + VmaAllocationInfo allocInfo; + vmaGetAllocationInfo(allocator, allocations[i], &allocInfo); + vmaBindBufferMemory(allocator, allocations[i], buffers[i]); + } +} +\endcode + +Setting VmaDefragmentationInfo2::pAllocationsChanged is optional. +This output array tells whether particular allocation in VmaDefragmentationInfo2::pAllocations at the same index +has been modified during defragmentation. +You can pass null, but you then need to query every allocation passed to defragmentation +for new parameters using vmaGetAllocationInfo() if you might need to recreate and rebind a buffer or image associated with it. + +If you use [Custom memory pools](@ref choosing_memory_type_custom_memory_pools), +you can fill VmaDefragmentationInfo2::poolCount and VmaDefragmentationInfo2::pPools +instead of VmaDefragmentationInfo2::allocationCount and VmaDefragmentationInfo2::pAllocations +to defragment all allocations in given pools. +You cannot use VmaDefragmentationInfo2::pAllocationsChanged in that case. +You can also combine both methods. + +\section defragmentation_gpu Defragmenting GPU memory + +It is also possible to defragment allocations created in memory types that are not `HOST_VISIBLE`. +To do that, you need to pass a command buffer that meets requirements as described in +VmaDefragmentationInfo2::commandBuffer. The way it works is: + +- It creates temporary buffers and binds them to entire memory blocks when necessary. +- It issues `vkCmdCopyBuffer()` to passed command buffer. + +Example: + +\code +// Given following variables already initialized: +VkDevice device; +VmaAllocator allocator; +VkCommandBuffer commandBuffer; +std::vector buffers; +std::vector allocations; + + +const uint32_t allocCount = (uint32_t)allocations.size(); +std::vector allocationsChanged(allocCount); + +VkCommandBufferBeginInfo cmdBufBeginInfo = ...; +vkBeginCommandBuffer(commandBuffer, &cmdBufBeginInfo); + +VmaDefragmentationInfo2 defragInfo = {}; +defragInfo.allocationCount = allocCount; +defragInfo.pAllocations = allocations.data(); +defragInfo.pAllocationsChanged = allocationsChanged.data(); +defragInfo.maxGpuBytesToMove = VK_WHOLE_SIZE; // Notice it's "GPU" this time. +defragInfo.maxGpuAllocationsToMove = UINT32_MAX; // Notice it's "GPU" this time. +defragInfo.commandBuffer = commandBuffer; + +VmaDefragmentationContext defragCtx; +vmaDefragmentationBegin(allocator, &defragInfo, nullptr, &defragCtx); + +vkEndCommandBuffer(commandBuffer); + +// Submit commandBuffer. +// Wait for a fence that ensures commandBuffer execution finished. + +vmaDefragmentationEnd(allocator, defragCtx); + +for(uint32_t i = 0; i < allocCount; ++i) +{ + if(allocationsChanged[i]) + { + // Destroy buffer that is immutably bound to memory region which is no longer valid. + vkDestroyBuffer(device, buffers[i], nullptr); + + // Create new buffer with same parameters. + VkBufferCreateInfo bufferInfo = ...; + vkCreateBuffer(device, &bufferInfo, nullptr, &buffers[i]); + + // You can make dummy call to vkGetBufferMemoryRequirements here to silence validation layer warning. + + // Bind new buffer to new memory region. Data contained in it is already moved. + VmaAllocationInfo allocInfo; + vmaGetAllocationInfo(allocator, allocations[i], &allocInfo); + vmaBindBufferMemory(allocator, allocations[i], buffers[i]); + } +} +\endcode + +You can combine these two methods by specifying non-zero `maxGpu*` as well as `maxCpu*` parameters. +The library automatically chooses best method to defragment each memory pool. + +You may try not to block your entire program to wait until defragmentation finishes, +but do it in the background, as long as you carefully fullfill requirements described +in function vmaDefragmentationBegin(). + +\section defragmentation_additional_notes Additional notes + +It is only legal to defragment allocations bound to: + +- buffers +- images created with `VK_IMAGE_CREATE_ALIAS_BIT`, `VK_IMAGE_TILING_LINEAR`, and + being currently in `VK_IMAGE_LAYOUT_GENERAL` or `VK_IMAGE_LAYOUT_PREINITIALIZED`. + +Defragmentation of images created with `VK_IMAGE_TILING_OPTIMAL` or in any other +layout may give undefined results. + +If you defragment allocations bound to images, new images to be bound to new +memory region after defragmentation should be created with `VK_IMAGE_LAYOUT_PREINITIALIZED` +and then transitioned to their original layout from before defragmentation if +needed using an image memory barrier. + +While using defragmentation, you may experience validation layer warnings, which you just need to ignore. +See [Validation layer warnings](@ref general_considerations_validation_layer_warnings). + +Please don't expect memory to be fully compacted after defragmentation. +Algorithms inside are based on some heuristics that try to maximize number of Vulkan +memory blocks to make totally empty to release them, as well as to maximimze continuous +empty space inside remaining blocks, while minimizing the number and size of allocations that +need to be moved. Some fragmentation may still remain - this is normal. + +\section defragmentation_custom_algorithm Writing custom defragmentation algorithm + +If you want to implement your own, custom defragmentation algorithm, +there is infrastructure prepared for that, +but it is not exposed through the library API - you need to hack its source code. +Here are steps needed to do this: + +-# Main thing you need to do is to define your own class derived from base abstract + class `VmaDefragmentationAlgorithm` and implement your version of its pure virtual methods. + See definition and comments of this class for details. +-# Your code needs to interact with device memory block metadata. + If you need more access to its data than it's provided by its public interface, + declare your new class as a friend class e.g. in class `VmaBlockMetadata_Generic`. +-# If you want to create a flag that would enable your algorithm or pass some additional + flags to configure it, add them to `VmaDefragmentationFlagBits` and use them in + VmaDefragmentationInfo2::flags. +-# Modify function `VmaBlockVectorDefragmentationContext::Begin` to create object + of your new class whenever needed. + + +\page lost_allocations Lost allocations + +If your game oversubscribes video memory, if may work OK in previous-generation +graphics APIs (DirectX 9, 10, 11, OpenGL) because resources are automatically +paged to system RAM. In Vulkan you can't do it because when you run out of +memory, an allocation just fails. If you have more data (e.g. textures) that can +fit into VRAM and you don't need it all at once, you may want to upload them to +GPU on demand and "push out" ones that are not used for a long time to make room +for the new ones, effectively using VRAM (or a cartain memory pool) as a form of +cache. Vulkan Memory Allocator can help you with that by supporting a concept of +"lost allocations". + +To create an allocation that can become lost, include #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT +flag in VmaAllocationCreateInfo::flags. Before using a buffer or image bound to +such allocation in every new frame, you need to query it if it's not lost. +To check it, call vmaTouchAllocation(). +If the allocation is lost, you should not use it or buffer/image bound to it. +You mustn't forget to destroy this allocation and this buffer/image. +vmaGetAllocationInfo() can also be used for checking status of the allocation. +Allocation is lost when returned VmaAllocationInfo::deviceMemory == `VK_NULL_HANDLE`. + +To create an allocation that can make some other allocations lost to make room +for it, use #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flag. You will +usually use both flags #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT and +#VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT at the same time. + +Warning! Current implementation uses quite naive, brute force algorithm, +which can make allocation calls that use #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT +flag quite slow. A new, more optimal algorithm and data structure to speed this +up is planned for the future. + +Q: When interleaving creation of new allocations with usage of existing ones, +how do you make sure that an allocation won't become lost while it's used in the +current frame? + +It is ensured because vmaTouchAllocation() / vmaGetAllocationInfo() not only returns allocation +status/parameters and checks whether it's not lost, but when it's not, it also +atomically marks it as used in the current frame, which makes it impossible to +become lost in that frame. It uses lockless algorithm, so it works fast and +doesn't involve locking any internal mutex. + +Q: What if my allocation may still be in use by the GPU when it's rendering a +previous frame while I already submit new frame on the CPU? + +You can make sure that allocations "touched" by vmaTouchAllocation() / vmaGetAllocationInfo() will not +become lost for a number of additional frames back from the current one by +specifying this number as VmaAllocatorCreateInfo::frameInUseCount (for default +memory pool) and VmaPoolCreateInfo::frameInUseCount (for custom pool). + +Q: How do you inform the library when new frame starts? + +You need to call function vmaSetCurrentFrameIndex(). + +Example code: + +\code +struct MyBuffer +{ + VkBuffer m_Buf = nullptr; + VmaAllocation m_Alloc = nullptr; + + // Called when the buffer is really needed in the current frame. + void EnsureBuffer(); +}; + +void MyBuffer::EnsureBuffer() +{ + // Buffer has been created. + if(m_Buf != VK_NULL_HANDLE) + { + // Check if its allocation is not lost + mark it as used in current frame. + if(vmaTouchAllocation(allocator, m_Alloc)) + { + // It's all OK - safe to use m_Buf. + return; + } + } + + // Buffer not yet exists or lost - destroy and recreate it. + + vmaDestroyBuffer(allocator, m_Buf, m_Alloc); + + VkBufferCreateInfo bufCreateInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; + bufCreateInfo.size = 1024; + bufCreateInfo.usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + + VmaAllocationCreateInfo allocCreateInfo = {}; + allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; + allocCreateInfo.flags = VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT | + VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT; + + vmaCreateBuffer(allocator, &bufCreateInfo, &allocCreateInfo, &m_Buf, &m_Alloc, nullptr); +} +\endcode + +When using lost allocations, you may see some Vulkan validation layer warnings +about overlapping regions of memory bound to different kinds of buffers and +images. This is still valid as long as you implement proper handling of lost +allocations (like in the example above) and don't use them. + +You can create an allocation that is already in lost state from the beginning using function +vmaCreateLostAllocation(). It may be useful if you need a "dummy" allocation that is not null. + +You can call function vmaMakePoolAllocationsLost() to set all eligible allocations +in a specified custom pool to lost state. +Allocations that have been "touched" in current frame or VmaPoolCreateInfo::frameInUseCount frames back +cannot become lost. + +Q: Can I touch allocation that cannot become lost? + +Yes, although it has no visible effect. +Calls to vmaGetAllocationInfo() and vmaTouchAllocation() update last use frame index +also for allocations that cannot become lost, but the only way to observe it is to dump +internal allocator state using vmaBuildStatsString(). +You can use this feature for debugging purposes to explicitly mark allocations that you use +in current frame and then analyze JSON dump to see for how long each allocation stays unused. + + +\page statistics Statistics + +This library contains functions that return information about its internal state, +especially the amount of memory allocated from Vulkan. +Please keep in mind that these functions need to traverse all internal data structures +to gather these information, so they may be quite time-consuming. +Don't call them too often. + +\section statistics_numeric_statistics Numeric statistics + +You can query for overall statistics of the allocator using function vmaCalculateStats(). +Information are returned using structure #VmaStats. +It contains #VmaStatInfo - number of allocated blocks, number of allocations +(occupied ranges in these blocks), number of unused (free) ranges in these blocks, +number of bytes used and unused (but still allocated from Vulkan) and other information. +They are summed across memory heaps, memory types and total for whole allocator. + +You can query for statistics of a custom pool using function vmaGetPoolStats(). +Information are returned using structure #VmaPoolStats. + +You can query for information about specific allocation using function vmaGetAllocationInfo(). +It fill structure #VmaAllocationInfo. + +\section statistics_json_dump JSON dump + +You can dump internal state of the allocator to a string in JSON format using function vmaBuildStatsString(). +The result is guaranteed to be correct JSON. +It uses ANSI encoding. +Any strings provided by user (see [Allocation names](@ref allocation_names)) +are copied as-is and properly escaped for JSON, so if they use UTF-8, ISO-8859-2 or any other encoding, +this JSON string can be treated as using this encoding. +It must be freed using function vmaFreeStatsString(). + +The format of this JSON string is not part of official documentation of the library, +but it will not change in backward-incompatible way without increasing library major version number +and appropriate mention in changelog. + +The JSON string contains all the data that can be obtained using vmaCalculateStats(). +It can also contain detailed map of allocated memory blocks and their regions - +free and occupied by allocations. +This allows e.g. to visualize the memory or assess fragmentation. + + +\page allocation_annotation Allocation names and user data + +\section allocation_user_data Allocation user data + +You can annotate allocations with your own information, e.g. for debugging purposes. +To do that, fill VmaAllocationCreateInfo::pUserData field when creating +an allocation. It's an opaque `void*` pointer. You can use it e.g. as a pointer, +some handle, index, key, ordinal number or any other value that would associate +the allocation with your custom metadata. + +\code +VkBufferCreateInfo bufferInfo = { VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO }; +// Fill bufferInfo... + +MyBufferMetadata* pMetadata = CreateBufferMetadata(); + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; +allocCreateInfo.pUserData = pMetadata; + +VkBuffer buffer; +VmaAllocation allocation; +vmaCreateBuffer(allocator, &bufferInfo, &allocCreateInfo, &buffer, &allocation, nullptr); +\endcode + +The pointer may be later retrieved as VmaAllocationInfo::pUserData: + +\code +VmaAllocationInfo allocInfo; +vmaGetAllocationInfo(allocator, allocation, &allocInfo); +MyBufferMetadata* pMetadata = (MyBufferMetadata*)allocInfo.pUserData; +\endcode + +It can also be changed using function vmaSetAllocationUserData(). + +Values of (non-zero) allocations' `pUserData` are printed in JSON report created by +vmaBuildStatsString(), in hexadecimal form. + +\section allocation_names Allocation names + +There is alternative mode available where `pUserData` pointer is used to point to +a null-terminated string, giving a name to the allocation. To use this mode, +set #VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT flag in VmaAllocationCreateInfo::flags. +Then `pUserData` passed as VmaAllocationCreateInfo::pUserData or argument to +vmaSetAllocationUserData() must be either null or pointer to a null-terminated string. +The library creates internal copy of the string, so the pointer you pass doesn't need +to be valid for whole lifetime of the allocation. You can free it after the call. + +\code +VkImageCreateInfo imageInfo = { VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO }; +// Fill imageInfo... + +std::string imageName = "Texture: "; +imageName += fileName; + +VmaAllocationCreateInfo allocCreateInfo = {}; +allocCreateInfo.usage = VMA_MEMORY_USAGE_GPU_ONLY; +allocCreateInfo.flags = VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT; +allocCreateInfo.pUserData = imageName.c_str(); + +VkImage image; +VmaAllocation allocation; +vmaCreateImage(allocator, &imageInfo, &allocCreateInfo, &image, &allocation, nullptr); +\endcode + +The value of `pUserData` pointer of the allocation will be different than the one +you passed when setting allocation's name - pointing to a buffer managed +internally that holds copy of the string. + +\code +VmaAllocationInfo allocInfo; +vmaGetAllocationInfo(allocator, allocation, &allocInfo); +const char* imageName = (const char*)allocInfo.pUserData; +printf("Image name: %s\n", imageName); +\endcode + +That string is also printed in JSON report created by vmaBuildStatsString(). + +\note Passing string name to VMA allocation doesn't automatically set it to the Vulkan buffer or image created with it. +You must do it manually using an extension like VK_EXT_debug_utils, which is independent of this library. + + +\page debugging_memory_usage Debugging incorrect memory usage + +If you suspect a bug with memory usage, like usage of uninitialized memory or +memory being overwritten out of bounds of an allocation, +you can use debug features of this library to verify this. + +\section debugging_memory_usage_initialization Memory initialization + +If you experience a bug with incorrect and nondeterministic data in your program and you suspect uninitialized memory to be used, +you can enable automatic memory initialization to verify this. +To do it, define macro `VMA_DEBUG_INITIALIZE_ALLOCATIONS` to 1. + +\code +#define VMA_DEBUG_INITIALIZE_ALLOCATIONS 1 +#include vk_mem_alloc.h +\endcode + +It makes memory of all new allocations initialized to bit pattern `0xDCDCDCDC`. +Before an allocation is destroyed, its memory is filled with bit pattern `0xEFEFEFEF`. +Memory is automatically mapped and unmapped if necessary. + +If you find these values while debugging your program, good chances are that you incorrectly +read Vulkan memory that is allocated but not initialized, or already freed, respectively. + +Memory initialization works only with memory types that are `HOST_VISIBLE`. +It works also with dedicated allocations. +It doesn't work with allocations created with #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag, +as they cannot be mapped. + +\section debugging_memory_usage_margins Margins + +By default, allocations are laid out in memory blocks next to each other if possible +(considering required alignment, `bufferImageGranularity`, and `nonCoherentAtomSize`). + +![Allocations without margin](../gfx/Margins_1.png) + +Define macro `VMA_DEBUG_MARGIN` to some non-zero value (e.g. 16) to enforce specified +number of bytes as a margin before and after every allocation. + +\code +#define VMA_DEBUG_MARGIN 16 +#include vk_mem_alloc.h +\endcode + +![Allocations with margin](../gfx/Margins_2.png) + +If your bug goes away after enabling margins, it means it may be caused by memory +being overwritten outside of allocation boundaries. It is not 100% certain though. +Change in application behavior may also be caused by different order and distribution +of allocations across memory blocks after margins are applied. + +The margin is applied also before first and after last allocation in a block. +It may occur only once between two adjacent allocations. + +Margins work with all types of memory. + +Margin is applied only to allocations made out of memory blocks and not to dedicated +allocations, which have their own memory block of specific size. +It is thus not applied to allocations made using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT flag +or those automatically decided to put into dedicated allocations, e.g. due to its +large size or recommended by VK_KHR_dedicated_allocation extension. +Margins are also not active in custom pools created with #VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT flag. + +Margins appear in [JSON dump](@ref statistics_json_dump) as part of free space. + +Note that enabling margins increases memory usage and fragmentation. + +\section debugging_memory_usage_corruption_detection Corruption detection + +You can additionally define macro `VMA_DEBUG_DETECT_CORRUPTION` to 1 to enable validation +of contents of the margins. + +\code +#define VMA_DEBUG_MARGIN 16 +#define VMA_DEBUG_DETECT_CORRUPTION 1 +#include vk_mem_alloc.h +\endcode + +When this feature is enabled, number of bytes specified as `VMA_DEBUG_MARGIN` +(it must be multiply of 4) before and after every allocation is filled with a magic number. +This idea is also know as "canary". +Memory is automatically mapped and unmapped if necessary. + +This number is validated automatically when the allocation is destroyed. +If it's not equal to the expected value, `VMA_ASSERT()` is executed. +It clearly means that either CPU or GPU overwritten the memory outside of boundaries of the allocation, +which indicates a serious bug. + +You can also explicitly request checking margins of all allocations in all memory blocks +that belong to specified memory types by using function vmaCheckCorruption(), +or in memory blocks that belong to specified custom pool, by using function +vmaCheckPoolCorruption(). + +Margin validation (corruption detection) works only for memory types that are +`HOST_VISIBLE` and `HOST_COHERENT`. + + +\page record_and_replay Record and replay + +\section record_and_replay_introduction Introduction + +While using the library, sequence of calls to its functions together with their +parameters can be recorded to a file and later replayed using standalone player +application. It can be useful to: + +- Test correctness - check if same sequence of calls will not cause crash or + failures on a target platform. +- Gather statistics - see number of allocations, peak memory usage, number of + calls etc. +- Benchmark performance - see how much time it takes to replay the whole + sequence. + +\section record_and_replay_usage Usage + +Recording functionality is disabled by default. +To enable it, define following macro before every include of this library: + +\code +#define VMA_RECORDING_ENABLED 1 +\endcode + +To record sequence of calls to a file: Fill in +VmaAllocatorCreateInfo::pRecordSettings member while creating #VmaAllocator +object. File is opened and written during whole lifetime of the allocator. + +To replay file: Use VmaReplay - standalone command-line program. +Precompiled binary can be found in "bin" directory. +Its source can be found in "src/VmaReplay" directory. +Its project is generated by Premake. +Command line syntax is printed when the program is launched without parameters. +Basic usage: + + VmaReplay.exe MyRecording.csv + +Documentation of file format can be found in file: "docs/Recording file format.md". +It's a human-readable, text file in CSV format (Comma Separated Values). + +\section record_and_replay_additional_considerations Additional considerations + +- Replaying file that was recorded on a different GPU (with different parameters + like `bufferImageGranularity`, `nonCoherentAtomSize`, and especially different + set of memory heaps and types) may give different performance and memory usage + results, as well as issue some warnings and errors. +- Current implementation of recording in VMA, as well as VmaReplay application, is + coded and tested only on Windows. Inclusion of recording code is driven by + `VMA_RECORDING_ENABLED` macro. Support for other platforms should be easy to + add. Contributions are welcomed. + + +\page usage_patterns Recommended usage patterns + +See also slides from talk: +[Sawicki, Adam. Advanced Graphics Techniques Tutorial: Memory management in Vulkan and DX12. Game Developers Conference, 2018](https://www.gdcvault.com/play/1025458/Advanced-Graphics-Techniques-Tutorial-New) + + +\section usage_patterns_common_mistakes Common mistakes + +Use of CPU_TO_GPU instead of CPU_ONLY memory + +#VMA_MEMORY_USAGE_CPU_TO_GPU is recommended only for resources that will be +mapped and written by the CPU, as well as read directly by the GPU - like some +buffers or textures updated every frame (dynamic). If you create a staging copy +of a resource to be written by CPU and then used as a source of transfer to +another resource placed in the GPU memory, that staging resource should be +created with #VMA_MEMORY_USAGE_CPU_ONLY. Please read the descriptions of these +enums carefully for details. + +Unnecessary use of custom pools + +\ref custom_memory_pools may be useful for special purposes - when you want to +keep certain type of resources separate e.g. to reserve minimum amount of memory +for them, limit maximum amount of memory they can occupy, or make some of them +push out the other through the mechanism of \ref lost_allocations. For most +resources this is not needed and so it is not recommended to create #VmaPool +objects and allocations out of them. Allocating from the default pool is sufficient. + +\section usage_patterns_simple Simple patterns + +\subsection usage_patterns_simple_render_targets Render targets + +When: +Any resources that you frequently write and read on GPU, +e.g. images used as color attachments (aka "render targets"), depth-stencil attachments, +images/buffers used as storage image/buffer (aka "Unordered Access View (UAV)"). + +What to do: +Create them in video memory that is fastest to access from GPU using +#VMA_MEMORY_USAGE_GPU_ONLY. + +Consider using [VK_KHR_dedicated_allocation](@ref vk_khr_dedicated_allocation) extension +and/or manually creating them as dedicated allocations using #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT, +especially if they are large or if you plan to destroy and recreate them e.g. when +display resolution changes. +Prefer to create such resources first and all other GPU resources (like textures and vertex buffers) later. + +\subsection usage_patterns_simple_immutable_resources Immutable resources + +When: +Any resources that you fill on CPU only once (aka "immutable") or infrequently +and then read frequently on GPU, +e.g. textures, vertex and index buffers, constant buffers that don't change often. + +What to do: +Create them in video memory that is fastest to access from GPU using +#VMA_MEMORY_USAGE_GPU_ONLY. + +To initialize content of such resource, create a CPU-side (aka "staging") copy of it +in system memory - #VMA_MEMORY_USAGE_CPU_ONLY, map it, fill it, +and submit a transfer from it to the GPU resource. +You can keep the staging copy if you need it for another upload transfer in the future. +If you don't, you can destroy it or reuse this buffer for uploading different resource +after the transfer finishes. + +Prefer to create just buffers in system memory rather than images, even for uploading textures. +Use `vkCmdCopyBufferToImage()`. +Dont use images with `VK_IMAGE_TILING_LINEAR`. + +\subsection usage_patterns_dynamic_resources Dynamic resources + +When: +Any resources that change frequently (aka "dynamic"), e.g. every frame or every draw call, +written on CPU, read on GPU. + +What to do: +Create them using #VMA_MEMORY_USAGE_CPU_TO_GPU. +You can map it and write to it directly on CPU, as well as read from it on GPU. + +This is a more complex situation. Different solutions are possible, +and the best one depends on specific GPU type, but you can use this simple approach for the start. +Prefer to write to such resource sequentially (e.g. using `memcpy`). +Don't perform random access or any reads from it on CPU, as it may be very slow. +Also note that textures written directly from the host through a mapped pointer need to be in LINEAR not OPTIMAL layout. + +\subsection usage_patterns_readback Readback + +When: +Resources that contain data written by GPU that you want to read back on CPU, +e.g. results of some computations. + +What to do: +Create them using #VMA_MEMORY_USAGE_GPU_TO_CPU. +You can write to them directly on GPU, as well as map and read them on CPU. + +\section usage_patterns_advanced Advanced patterns + +\subsection usage_patterns_integrated_graphics Detecting integrated graphics + +You can support integrated graphics (like Intel HD Graphics, AMD APU) better +by detecting it in Vulkan. +To do it, call `vkGetPhysicalDeviceProperties()`, inspect +`VkPhysicalDeviceProperties::deviceType` and look for `VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU`. +When you find it, you can assume that memory is unified and all memory types are comparably fast +to access from GPU, regardless of `VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT`. + +You can then sum up sizes of all available memory heaps and treat them as useful for +your GPU resources, instead of only `DEVICE_LOCAL` ones. +You can also prefer to create your resources in memory types that are `HOST_VISIBLE` to map them +directly instead of submitting explicit transfer (see below). + +\subsection usage_patterns_direct_vs_transfer Direct access versus transfer + +For resources that you frequently write on CPU and read on GPU, many solutions are possible: + +-# Create one copy in video memory using #VMA_MEMORY_USAGE_GPU_ONLY, + second copy in system memory using #VMA_MEMORY_USAGE_CPU_ONLY and submit explicit transfer each time. +-# Create just a single copy using #VMA_MEMORY_USAGE_CPU_TO_GPU, map it and fill it on CPU, + read it directly on GPU. +-# Create just a single copy using #VMA_MEMORY_USAGE_CPU_ONLY, map it and fill it on CPU, + read it directly on GPU. + +Which solution is the most efficient depends on your resource and especially on the GPU. +It is best to measure it and then make the decision. +Some general recommendations: + +- On integrated graphics use (2) or (3) to avoid unnecesary time and memory overhead + related to using a second copy and making transfer. +- For small resources (e.g. constant buffers) use (2). + Discrete AMD cards have special 256 MiB pool of video memory that is directly mappable. + Even if the resource ends up in system memory, its data may be cached on GPU after first + fetch over PCIe bus. +- For larger resources (e.g. textures), decide between (1) and (2). + You may want to differentiate NVIDIA and AMD, e.g. by looking for memory type that is + both `DEVICE_LOCAL` and `HOST_VISIBLE`. When you find it, use (2), otherwise use (1). + +Similarly, for resources that you frequently write on GPU and read on CPU, multiple +solutions are possible: + +-# Create one copy in video memory using #VMA_MEMORY_USAGE_GPU_ONLY, + second copy in system memory using #VMA_MEMORY_USAGE_GPU_TO_CPU and submit explicit tranfer each time. +-# Create just single copy using #VMA_MEMORY_USAGE_GPU_TO_CPU, write to it directly on GPU, + map it and read it on CPU. + +You should take some measurements to decide which option is faster in case of your specific +resource. + +Note that textures accessed directly from the host through a mapped pointer need to be in LINEAR layout, +which may slow down their usage on the device. +Textures accessed only by the device and transfer operations can use OPTIMAL layout. + +If you don't want to specialize your code for specific types of GPUs, you can still make +an simple optimization for cases when your resource ends up in mappable memory to use it +directly in this case instead of creating CPU-side staging copy. +For details see [Finding out if memory is mappable](@ref memory_mapping_finding_if_memory_mappable). + + +\page configuration Configuration + +Please check "CONFIGURATION SECTION" in the code to find macros that you can define +before each include of this file or change directly in this file to provide +your own implementation of basic facilities like assert, `min()` and `max()` functions, +mutex, atomic etc. +The library uses its own implementation of containers by default, but you can switch to using +STL containers instead. + +For example, define `VMA_ASSERT(expr)` before including the library to provide +custom implementation of the assertion, compatible with your project. +By default it is defined to standard C `assert(expr)` in `_DEBUG` configuration +and empty otherwise. + +\section config_Vulkan_functions Pointers to Vulkan functions + +There are multiple ways to import pointers to Vulkan functions in the library. +In the simplest case you don't need to do anything. +If the compilation or linking of your program or the initialization of the #VmaAllocator +doesn't work for you, you can try to reconfigure it. + +First, the allocator tries to fetch pointers to Vulkan functions linked statically, +like this: + +\code +m_VulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkAllocateMemory; +\endcode + +If you want to disable this feature, set configuration macro: `#define VMA_STATIC_VULKAN_FUNCTIONS 0`. + +Second, you can provide the pointers yourself by setting member VmaAllocatorCreateInfo::pVulkanFunctions. +You can fetch them e.g. using functions `vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` or +by using a helper library like [volk](https://github.com/zeux/volk). + +Third, VMA tries to fetch remaining pointers that are still null by calling +`vkGetInstanceProcAddr` and `vkGetDeviceProcAddr` on its own. +If you want to disable this feature, set configuration macro: `#define VMA_DYNAMIC_VULKAN_FUNCTIONS 0`. + +Finally, all the function pointers required by the library (considering selected +Vulkan version and enabled extensions) are checked with `VMA_ASSERT` if they are not null. + + +\section custom_memory_allocator Custom host memory allocator + +If you use custom allocator for CPU memory rather than default operator `new` +and `delete` from C++, you can make this library using your allocator as well +by filling optional member VmaAllocatorCreateInfo::pAllocationCallbacks. These +functions will be passed to Vulkan, as well as used by the library itself to +make any CPU-side allocations. + +\section allocation_callbacks Device memory allocation callbacks + +The library makes calls to `vkAllocateMemory()` and `vkFreeMemory()` internally. +You can setup callbacks to be informed about these calls, e.g. for the purpose +of gathering some statistics. To do it, fill optional member +VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. + +\section heap_memory_limit Device heap memory limit + +When device memory of certain heap runs out of free space, new allocations may +fail (returning error code) or they may succeed, silently pushing some existing +memory blocks from GPU VRAM to system RAM (which degrades performance). This +behavior is implementation-dependant - it depends on GPU vendor and graphics +driver. + +On AMD cards it can be controlled while creating Vulkan device object by using +VK_AMD_memory_overallocation_behavior extension, if available. + +Alternatively, if you want to test how your program behaves with limited amount of Vulkan device +memory available without switching your graphics card to one that really has +smaller VRAM, you can use a feature of this library intended for this purpose. +To do it, fill optional member VmaAllocatorCreateInfo::pHeapSizeLimit. + + + +\page vk_khr_dedicated_allocation VK_KHR_dedicated_allocation + +VK_KHR_dedicated_allocation is a Vulkan extension which can be used to improve +performance on some GPUs. It augments Vulkan API with possibility to query +driver whether it prefers particular buffer or image to have its own, dedicated +allocation (separate `VkDeviceMemory` block) for better efficiency - to be able +to do some internal optimizations. + +The extension is supported by this library. It will be used automatically when +enabled. To enable it: + +1 . When creating Vulkan device, check if following 2 device extensions are +supported (call `vkEnumerateDeviceExtensionProperties()`). +If yes, enable them (fill `VkDeviceCreateInfo::ppEnabledExtensionNames`). + +- VK_KHR_get_memory_requirements2 +- VK_KHR_dedicated_allocation + +If you enabled these extensions: + +2 . Use #VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag when creating +your #VmaAllocator`to inform the library that you enabled required extensions +and you want the library to use them. + +\code +allocatorInfo.flags |= VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT; + +vmaCreateAllocator(&allocatorInfo, &allocator); +\endcode + +That's all. The extension will be automatically used whenever you create a +buffer using vmaCreateBuffer() or image using vmaCreateImage(). + +When using the extension together with Vulkan Validation Layer, you will receive +warnings like this: + + vkBindBufferMemory(): Binding memory to buffer 0x33 but vkGetBufferMemoryRequirements() has not been called on that buffer. + +It is OK, you should just ignore it. It happens because you use function +`vkGetBufferMemoryRequirements2KHR()` instead of standard +`vkGetBufferMemoryRequirements()`, while the validation layer seems to be +unaware of it. + +To learn more about this extension, see: + +- [VK_KHR_dedicated_allocation in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap44.html#VK_KHR_dedicated_allocation) +- [VK_KHR_dedicated_allocation unofficial manual](http://asawicki.info/articles/VK_KHR_dedicated_allocation.php5) + + + +\page vk_amd_device_coherent_memory VK_AMD_device_coherent_memory + +VK_AMD_device_coherent_memory is a device extension that enables access to +additional memory types with `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` and +`VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` flag. It is useful mostly for +allocation of buffers intended for writing "breadcrumb markers" in between passes +or draw calls, which in turn are useful for debugging GPU crash/hang/TDR cases. + +When the extension is available but has not been enabled, Vulkan physical device +still exposes those memory types, but their usage is forbidden. VMA automatically +takes care of that - it returns `VK_ERROR_FEATURE_NOT_PRESENT` when an attempt +to allocate memory of such type is made. + +If you want to use this extension in connection with VMA, follow these steps: + +\section vk_amd_device_coherent_memory_initialization Initialization + +1) Call `vkEnumerateDeviceExtensionProperties` for the physical device. +Check if the extension is supported - if returned array of `VkExtensionProperties` contains "VK_AMD_device_coherent_memory". + +2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. +Attach additional structure `VkPhysicalDeviceCoherentMemoryFeaturesAMD` to `VkPhysicalDeviceFeatures2::pNext` to be returned. +Check if the device feature is really supported - check if `VkPhysicalDeviceCoherentMemoryFeaturesAMD::deviceCoherentMemory` is true. + +3) While creating device with `vkCreateDevice`, enable this extension - add "VK_AMD_device_coherent_memory" +to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. + +4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. +Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. +Enable this device feature - attach additional structure `VkPhysicalDeviceCoherentMemoryFeaturesAMD` to +`VkPhysicalDeviceFeatures2::pNext` and set its member `deviceCoherentMemory` to `VK_TRUE`. + +5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you +have enabled this extension and feature - add #VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT +to VmaAllocatorCreateInfo::flags. + +\section vk_amd_device_coherent_memory_usage Usage + +After following steps described above, you can create VMA allocations and custom pools +out of the special `DEVICE_COHERENT` and `DEVICE_UNCACHED` memory types on eligible +devices. There are multiple ways to do it, for example: + +- You can request or prefer to allocate out of such memory types by adding + `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` to VmaAllocationCreateInfo::requiredFlags + or VmaAllocationCreateInfo::preferredFlags. Those flags can be freely mixed with + other ways of \ref choosing_memory_type, like setting VmaAllocationCreateInfo::usage. +- If you manually found memory type index to use for this purpose, force allocation + from this specific index by setting VmaAllocationCreateInfo::memoryTypeBits `= 1u << index`. + +\section vk_amd_device_coherent_memory_more_information More information + +To learn more about this extension, see [VK_AMD_device_coherent_memory in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap44.html#VK_AMD_device_coherent_memory) + +Example use of this extension can be found in the code of the sample and test suite +accompanying this library. + + +\page enabling_buffer_device_address Enabling buffer device address + +Device extension VK_KHR_buffer_device_address +allow to fetch raw GPU pointer to a buffer and pass it for usage in a shader code. +It is promoted to core Vulkan 1.2. + +If you want to use this feature in connection with VMA, follow these steps: + +\section enabling_buffer_device_address_initialization Initialization + +1) (For Vulkan version < 1.2) Call `vkEnumerateDeviceExtensionProperties` for the physical device. +Check if the extension is supported - if returned array of `VkExtensionProperties` contains +"VK_KHR_buffer_device_address". + +2) Call `vkGetPhysicalDeviceFeatures2` for the physical device instead of old `vkGetPhysicalDeviceFeatures`. +Attach additional structure `VkPhysicalDeviceBufferDeviceAddressFeatures*` to `VkPhysicalDeviceFeatures2::pNext` to be returned. +Check if the device feature is really supported - check if `VkPhysicalDeviceBufferDeviceAddressFeatures*::bufferDeviceAddress` is true. + +3) (For Vulkan version < 1.2) While creating device with `vkCreateDevice`, enable this extension - add +"VK_KHR_buffer_device_address" to the list passed as `VkDeviceCreateInfo::ppEnabledExtensionNames`. + +4) While creating the device, also don't set `VkDeviceCreateInfo::pEnabledFeatures`. +Fill in `VkPhysicalDeviceFeatures2` structure instead and pass it as `VkDeviceCreateInfo::pNext`. +Enable this device feature - attach additional structure `VkPhysicalDeviceBufferDeviceAddressFeatures*` to +`VkPhysicalDeviceFeatures2::pNext` and set its member `bufferDeviceAddress` to `VK_TRUE`. + +5) While creating #VmaAllocator with vmaCreateAllocator() inform VMA that you +have enabled this feature - add #VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT +to VmaAllocatorCreateInfo::flags. + +\section enabling_buffer_device_address_usage Usage + +After following steps described above, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*` using VMA. +The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT*` to +allocated memory blocks wherever it might be needed. + +Please note that the library supports only `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*`. +The second part of this functionality related to "capture and replay" is not supported, +as it is intended for usage in debugging tools like RenderDoc, not in everyday Vulkan usage. + +\section enabling_buffer_device_address_more_information More information + +To learn more about this extension, see [VK_KHR_buffer_device_address in Vulkan specification](https://www.khronos.org/registry/vulkan/specs/1.2-extensions/html/chap46.html#VK_KHR_buffer_device_address) + +Example use of this extension can be found in the code of the sample and test suite +accompanying this library. + +\page general_considerations General considerations + +\section general_considerations_thread_safety Thread safety + +- The library has no global state, so separate #VmaAllocator objects can be used + independently. + There should be no need to create multiple such objects though - one per `VkDevice` is enough. +- By default, all calls to functions that take #VmaAllocator as first parameter + are safe to call from multiple threads simultaneously because they are + synchronized internally when needed. +- When the allocator is created with #VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT + flag, calls to functions that take such #VmaAllocator object must be + synchronized externally. +- Access to a #VmaAllocation object must be externally synchronized. For example, + you must not call vmaGetAllocationInfo() and vmaMapMemory() from different + threads at the same time if you pass the same #VmaAllocation object to these + functions. + +\section general_considerations_validation_layer_warnings Validation layer warnings + +When using this library, you can meet following types of warnings issued by +Vulkan validation layer. They don't necessarily indicate a bug, so you may need +to just ignore them. + +- *vkBindBufferMemory(): Binding memory to buffer 0xeb8e4 but vkGetBufferMemoryRequirements() has not been called on that buffer.* + - It happens when VK_KHR_dedicated_allocation extension is enabled. + `vkGetBufferMemoryRequirements2KHR` function is used instead, while validation layer seems to be unaware of it. +- *Mapping an image with layout VK_IMAGE_LAYOUT_DEPTH_STENCIL_ATTACHMENT_OPTIMAL can result in undefined behavior if this memory is used by the device. Only GENERAL or PREINITIALIZED should be used.* + - It happens when you map a buffer or image, because the library maps entire + `VkDeviceMemory` block, where different types of images and buffers may end + up together, especially on GPUs with unified memory like Intel. +- *Non-linear image 0xebc91 is aliased with linear buffer 0xeb8e4 which may indicate a bug.* + - It happens when you use lost allocations, and a new image or buffer is + created in place of an existing object that bacame lost. + - It may happen also when you use [defragmentation](@ref defragmentation). + +\section general_considerations_allocation_algorithm Allocation algorithm + +The library uses following algorithm for allocation, in order: + +-# Try to find free range of memory in existing blocks. +-# If failed, try to create a new block of `VkDeviceMemory`, with preferred block size. +-# If failed, try to create such block with size/2, size/4, size/8. +-# If failed and #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flag was + specified, try to find space in existing blocks, possilby making some other + allocations lost. +-# If failed, try to allocate separate `VkDeviceMemory` for this allocation, + just like when you use #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +-# If failed, choose other memory type that meets the requirements specified in + VmaAllocationCreateInfo and go to point 1. +-# If failed, return `VK_ERROR_OUT_OF_DEVICE_MEMORY`. + +\section general_considerations_features_not_supported Features not supported + +Features deliberately excluded from the scope of this library: + +- Data transfer. Uploading (straming) and downloading data of buffers and images + between CPU and GPU memory and related synchronization is responsibility of the user. + Defining some "texture" object that would automatically stream its data from a + staging copy in CPU memory to GPU memory would rather be a feature of another, + higher-level library implemented on top of VMA. +- Allocations for imported/exported external memory. They tend to require + explicit memory type index and dedicated allocation anyway, so they don't + interact with main features of this library. Such special purpose allocations + should be made manually, using `vkCreateBuffer()` and `vkAllocateMemory()`. +- Recreation of buffers and images. Although the library has functions for + buffer and image creation (vmaCreateBuffer(), vmaCreateImage()), you need to + recreate these objects yourself after defragmentation. That's because the big + structures `VkBufferCreateInfo`, `VkImageCreateInfo` are not stored in + #VmaAllocation object. +- Handling CPU memory allocation failures. When dynamically creating small C++ + objects in CPU memory (not Vulkan memory), allocation failures are not checked + and handled gracefully, because that would complicate code significantly and + is usually not needed in desktop PC applications anyway. + Success of an allocation is just checked with an assert. +- Code free of any compiler warnings. Maintaining the library to compile and + work correctly on so many different platforms is hard enough. Being free of + any warnings, on any version of any compiler, is simply not feasible. +- This is a C++ library with C interface. + Bindings or ports to any other programming languages are welcomed as external projects and + are not going to be included into this repository. + +*/ + +#if VMA_RECORDING_ENABLED + #include + #if defined(_WIN32) + #include + #else + #include + #include + #endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/* +Define this macro to 0/1 to disable/enable support for recording functionality, +available through VmaAllocatorCreateInfo::pRecordSettings. +*/ +#ifndef VMA_RECORDING_ENABLED + #define VMA_RECORDING_ENABLED 0 +#endif + +#ifndef NOMINMAX + #define NOMINMAX // For windows.h +#endif + +#if defined(__ANDROID__) && defined(VK_NO_PROTOTYPES) && VMA_STATIC_VULKAN_FUNCTIONS + extern PFN_vkGetInstanceProcAddr vkGetInstanceProcAddr; + extern PFN_vkGetDeviceProcAddr vkGetDeviceProcAddr; + extern PFN_vkGetPhysicalDeviceProperties vkGetPhysicalDeviceProperties; + extern PFN_vkGetPhysicalDeviceMemoryProperties vkGetPhysicalDeviceMemoryProperties; + extern PFN_vkAllocateMemory vkAllocateMemory; + extern PFN_vkFreeMemory vkFreeMemory; + extern PFN_vkMapMemory vkMapMemory; + extern PFN_vkUnmapMemory vkUnmapMemory; + extern PFN_vkFlushMappedMemoryRanges vkFlushMappedMemoryRanges; + extern PFN_vkInvalidateMappedMemoryRanges vkInvalidateMappedMemoryRanges; + extern PFN_vkBindBufferMemory vkBindBufferMemory; + extern PFN_vkBindImageMemory vkBindImageMemory; + extern PFN_vkGetBufferMemoryRequirements vkGetBufferMemoryRequirements; + extern PFN_vkGetImageMemoryRequirements vkGetImageMemoryRequirements; + extern PFN_vkCreateBuffer vkCreateBuffer; + extern PFN_vkDestroyBuffer vkDestroyBuffer; + extern PFN_vkCreateImage vkCreateImage; + extern PFN_vkDestroyImage vkDestroyImage; + extern PFN_vkCmdCopyBuffer vkCmdCopyBuffer; + #if VMA_VULKAN_VERSION >= 1001000 + extern PFN_vkGetBufferMemoryRequirements2 vkGetBufferMemoryRequirements2; + extern PFN_vkGetImageMemoryRequirements2 vkGetImageMemoryRequirements2; + extern PFN_vkBindBufferMemory2 vkBindBufferMemory2; + extern PFN_vkBindImageMemory2 vkBindImageMemory2; + extern PFN_vkGetPhysicalDeviceMemoryProperties2 vkGetPhysicalDeviceMemoryProperties2; + #endif // #if VMA_VULKAN_VERSION >= 1001000 +#endif // #if defined(__ANDROID__) && VMA_STATIC_VULKAN_FUNCTIONS && VK_NO_PROTOTYPES + +#ifndef VULKAN_H_ + #include +#endif + +// Define this macro to declare maximum supported Vulkan version in format AAABBBCCC, +// where AAA = major, BBB = minor, CCC = patch. +// If you want to use version > 1.0, it still needs to be enabled via VmaAllocatorCreateInfo::vulkanApiVersion. +#if !defined(VMA_VULKAN_VERSION) + #if defined(VK_VERSION_1_2) + #define VMA_VULKAN_VERSION 1002000 + #elif defined(VK_VERSION_1_1) + #define VMA_VULKAN_VERSION 1001000 + #else + #define VMA_VULKAN_VERSION 1000000 + #endif +#endif + +#if !defined(VMA_DEDICATED_ALLOCATION) + #if VK_KHR_get_memory_requirements2 && VK_KHR_dedicated_allocation + #define VMA_DEDICATED_ALLOCATION 1 + #else + #define VMA_DEDICATED_ALLOCATION 0 + #endif +#endif + +#if !defined(VMA_BIND_MEMORY2) + #if VK_KHR_bind_memory2 + #define VMA_BIND_MEMORY2 1 + #else + #define VMA_BIND_MEMORY2 0 + #endif +#endif + +#if !defined(VMA_MEMORY_BUDGET) + #if VK_EXT_memory_budget && (VK_KHR_get_physical_device_properties2 || VMA_VULKAN_VERSION >= 1001000) + #define VMA_MEMORY_BUDGET 1 + #else + #define VMA_MEMORY_BUDGET 0 + #endif +#endif + +// Defined to 1 when VK_KHR_buffer_device_address device extension or equivalent core Vulkan 1.2 feature is defined in its headers. +#if !defined(VMA_BUFFER_DEVICE_ADDRESS) + #if VK_KHR_buffer_device_address || VMA_VULKAN_VERSION >= 1002000 + #define VMA_BUFFER_DEVICE_ADDRESS 1 + #else + #define VMA_BUFFER_DEVICE_ADDRESS 0 + #endif +#endif + +// Define these macros to decorate all public functions with additional code, +// before and after returned type, appropriately. This may be useful for +// exporing the functions when compiling VMA as a separate library. Example: +// #define VMA_CALL_PRE __declspec(dllexport) +// #define VMA_CALL_POST __cdecl +#ifndef VMA_CALL_PRE + #define VMA_CALL_PRE +#endif +#ifndef VMA_CALL_POST + #define VMA_CALL_POST +#endif + +// Define this macro to decorate pointers with an attribute specifying the +// length of the array they point to if they are not null. +// +// The length may be one of +// - The name of another parameter in the argument list where the pointer is declared +// - The name of another member in the struct where the pointer is declared +// - The name of a member of a struct type, meaning the value of that member in +// the context of the call. For example +// VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount"), +// this means the number of memory heaps available in the device associated +// with the VmaAllocator being dealt with. +#ifndef VMA_LEN_IF_NOT_NULL + #define VMA_LEN_IF_NOT_NULL(len) +#endif + +// The VMA_NULLABLE macro is defined to be _Nullable when compiling with Clang. +// see: https://clang.llvm.org/docs/AttributeReference.html#nullable +#ifndef VMA_NULLABLE + #ifdef __clang__ + #define VMA_NULLABLE _Nullable + #else + #define VMA_NULLABLE + #endif +#endif + +// The VMA_NOT_NULL macro is defined to be _Nonnull when compiling with Clang. +// see: https://clang.llvm.org/docs/AttributeReference.html#nonnull +#ifndef VMA_NOT_NULL + #ifdef __clang__ + #define VMA_NOT_NULL _Nonnull + #else + #define VMA_NOT_NULL + #endif +#endif + +// If non-dispatchable handles are represented as pointers then we can give +// then nullability annotations +#ifndef VMA_NOT_NULL_NON_DISPATCHABLE + #if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__) ) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + #define VMA_NOT_NULL_NON_DISPATCHABLE VMA_NOT_NULL + #else + #define VMA_NOT_NULL_NON_DISPATCHABLE + #endif +#endif + +#ifndef VMA_NULLABLE_NON_DISPATCHABLE + #if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__) ) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + #define VMA_NULLABLE_NON_DISPATCHABLE VMA_NULLABLE + #else + #define VMA_NULLABLE_NON_DISPATCHABLE + #endif +#endif + +/** \struct VmaAllocator +\brief Represents main object of this library initialized. + +Fill structure #VmaAllocatorCreateInfo and call function vmaCreateAllocator() to create it. +Call function vmaDestroyAllocator() to destroy it. + +It is recommended to create just one object of this type per `VkDevice` object, +right after Vulkan is initialized and keep it alive until before Vulkan device is destroyed. +*/ +VK_DEFINE_HANDLE(VmaAllocator) + +/// Callback function called after successful vkAllocateMemory. +typedef void (VKAPI_PTR *PFN_vmaAllocateDeviceMemoryFunction)( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryType, + VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory, + VkDeviceSize size, + void* VMA_NULLABLE pUserData); +/// Callback function called before vkFreeMemory. +typedef void (VKAPI_PTR *PFN_vmaFreeDeviceMemoryFunction)( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryType, + VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory, + VkDeviceSize size, + void* VMA_NULLABLE pUserData); + +/** \brief Set of callbacks that the library will call for `vkAllocateMemory` and `vkFreeMemory`. + +Provided for informative purpose, e.g. to gather statistics about number of +allocations or total amount of memory allocated in Vulkan. + +Used in VmaAllocatorCreateInfo::pDeviceMemoryCallbacks. +*/ +typedef struct VmaDeviceMemoryCallbacks { + /// Optional, can be null. + PFN_vmaAllocateDeviceMemoryFunction VMA_NULLABLE pfnAllocate; + /// Optional, can be null. + PFN_vmaFreeDeviceMemoryFunction VMA_NULLABLE pfnFree; + /// Optional, can be null. + void* VMA_NULLABLE pUserData; +} VmaDeviceMemoryCallbacks; + +/// Flags for created #VmaAllocator. +typedef enum VmaAllocatorCreateFlagBits { + /** \brief Allocator and all objects created from it will not be synchronized internally, so you must guarantee they are used from only one thread at a time or synchronized externally by you. + + Using this flag may increase performance because internal mutexes are not used. + */ + VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT = 0x00000001, + /** \brief Enables usage of VK_KHR_dedicated_allocation extension. + + The flag works only if VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_0`. + When it's `VK_API_VERSION_1_1`, the flag is ignored because the extension has been promoted to Vulkan 1.1. + + Using this extenion will automatically allocate dedicated blocks of memory for + some buffers and images instead of suballocating place for them out of bigger + memory blocks (as if you explicitly used #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT + flag) when it is recommended by the driver. It may improve performance on some + GPUs. + + You may set this flag only if you found out that following device extensions are + supported, you enabled them while creating Vulkan device passed as + VmaAllocatorCreateInfo::device, and you want them to be used internally by this + library: + + - VK_KHR_get_memory_requirements2 (device extension) + - VK_KHR_dedicated_allocation (device extension) + + When this flag is set, you can experience following warnings reported by Vulkan + validation layer. You can ignore them. + + > vkBindBufferMemory(): Binding memory to buffer 0x2d but vkGetBufferMemoryRequirements() has not been called on that buffer. + */ + VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT = 0x00000002, + /** + Enables usage of VK_KHR_bind_memory2 extension. + + The flag works only if VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_0`. + When it's `VK_API_VERSION_1_1`, the flag is ignored because the extension has been promoted to Vulkan 1.1. + + You may set this flag only if you found out that this device extension is supported, + you enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, + and you want it to be used internally by this library. + + The extension provides functions `vkBindBufferMemory2KHR` and `vkBindImageMemory2KHR`, + which allow to pass a chain of `pNext` structures while binding. + This flag is required if you use `pNext` parameter in vmaBindBufferMemory2() or vmaBindImageMemory2(). + */ + VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT = 0x00000004, + /** + Enables usage of VK_EXT_memory_budget extension. + + You may set this flag only if you found out that this device extension is supported, + you enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, + and you want it to be used internally by this library, along with another instance extension + VK_KHR_get_physical_device_properties2, which is required by it (or Vulkan 1.1, where this extension is promoted). + + The extension provides query for current memory usage and budget, which will probably + be more accurate than an estimation used by the library otherwise. + */ + VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT = 0x00000008, + /** + Enables usage of VK_AMD_device_coherent_memory extension. + + You may set this flag only if you: + + - found out that this device extension is supported and enabled it while creating Vulkan device passed as VmaAllocatorCreateInfo::device, + - checked that `VkPhysicalDeviceCoherentMemoryFeaturesAMD::deviceCoherentMemory` is true and set it while creating the Vulkan device, + - want it to be used internally by this library. + + The extension and accompanying device feature provide access to memory types with + `VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD` and `VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD` flags. + They are useful mostly for writing breadcrumb markers - a common method for debugging GPU crash/hang/TDR. + + When the extension is not enabled, such memory types are still enumerated, but their usage is illegal. + To protect from this error, if you don't create the allocator with this flag, it will refuse to allocate any memory or create a custom pool in such memory type, + returning `VK_ERROR_FEATURE_NOT_PRESENT`. + */ + VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT = 0x00000010, + /** + Enables usage of "buffer device address" feature, which allows you to use function + `vkGetBufferDeviceAddress*` to get raw GPU pointer to a buffer and pass it for usage inside a shader. + + You may set this flag only if you: + + 1. (For Vulkan version < 1.2) Found as available and enabled device extension + VK_KHR_buffer_device_address. + This extension is promoted to core Vulkan 1.2. + 2. Found as available and enabled device feature `VkPhysicalDeviceBufferDeviceAddressFeatures*::bufferDeviceAddress`. + + When this flag is set, you can create buffers with `VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT*` using VMA. + The library automatically adds `VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT*` to + allocated memory blocks wherever it might be needed. + + For more information, see documentation chapter \ref enabling_buffer_device_address. + */ + VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT = 0x00000020, + + VMA_ALLOCATOR_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaAllocatorCreateFlagBits; +typedef VkFlags VmaAllocatorCreateFlags; + +/** \brief Pointers to some Vulkan functions - a subset used by the library. + +Used in VmaAllocatorCreateInfo::pVulkanFunctions. +*/ +typedef struct VmaVulkanFunctions { + PFN_vkGetPhysicalDeviceProperties VMA_NULLABLE vkGetPhysicalDeviceProperties; + PFN_vkGetPhysicalDeviceMemoryProperties VMA_NULLABLE vkGetPhysicalDeviceMemoryProperties; + PFN_vkAllocateMemory VMA_NULLABLE vkAllocateMemory; + PFN_vkFreeMemory VMA_NULLABLE vkFreeMemory; + PFN_vkMapMemory VMA_NULLABLE vkMapMemory; + PFN_vkUnmapMemory VMA_NULLABLE vkUnmapMemory; + PFN_vkFlushMappedMemoryRanges VMA_NULLABLE vkFlushMappedMemoryRanges; + PFN_vkInvalidateMappedMemoryRanges VMA_NULLABLE vkInvalidateMappedMemoryRanges; + PFN_vkBindBufferMemory VMA_NULLABLE vkBindBufferMemory; + PFN_vkBindImageMemory VMA_NULLABLE vkBindImageMemory; + PFN_vkGetBufferMemoryRequirements VMA_NULLABLE vkGetBufferMemoryRequirements; + PFN_vkGetImageMemoryRequirements VMA_NULLABLE vkGetImageMemoryRequirements; + PFN_vkCreateBuffer VMA_NULLABLE vkCreateBuffer; + PFN_vkDestroyBuffer VMA_NULLABLE vkDestroyBuffer; + PFN_vkCreateImage VMA_NULLABLE vkCreateImage; + PFN_vkDestroyImage VMA_NULLABLE vkDestroyImage; + PFN_vkCmdCopyBuffer VMA_NULLABLE vkCmdCopyBuffer; +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + PFN_vkGetBufferMemoryRequirements2KHR VMA_NULLABLE vkGetBufferMemoryRequirements2KHR; + PFN_vkGetImageMemoryRequirements2KHR VMA_NULLABLE vkGetImageMemoryRequirements2KHR; +#endif +#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 + PFN_vkBindBufferMemory2KHR VMA_NULLABLE vkBindBufferMemory2KHR; + PFN_vkBindImageMemory2KHR VMA_NULLABLE vkBindImageMemory2KHR; +#endif +#if VMA_MEMORY_BUDGET || VMA_VULKAN_VERSION >= 1001000 + PFN_vkGetPhysicalDeviceMemoryProperties2KHR VMA_NULLABLE vkGetPhysicalDeviceMemoryProperties2KHR; +#endif +} VmaVulkanFunctions; + +/// Flags to be used in VmaRecordSettings::flags. +typedef enum VmaRecordFlagBits { + /** \brief Enables flush after recording every function call. + + Enable it if you expect your application to crash, which may leave recording file truncated. + It may degrade performance though. + */ + VMA_RECORD_FLUSH_AFTER_CALL_BIT = 0x00000001, + + VMA_RECORD_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaRecordFlagBits; +typedef VkFlags VmaRecordFlags; + +/// Parameters for recording calls to VMA functions. To be used in VmaAllocatorCreateInfo::pRecordSettings. +typedef struct VmaRecordSettings +{ + /// Flags for recording. Use #VmaRecordFlagBits enum. + VmaRecordFlags flags; + /** \brief Path to the file that should be written by the recording. + + Suggested extension: "csv". + If the file already exists, it will be overwritten. + It will be opened for the whole time #VmaAllocator object is alive. + If opening this file fails, creation of the whole allocator object fails. + */ + const char* VMA_NOT_NULL pFilePath; +} VmaRecordSettings; + +/// Description of a Allocator to be created. +typedef struct VmaAllocatorCreateInfo +{ + /// Flags for created allocator. Use #VmaAllocatorCreateFlagBits enum. + VmaAllocatorCreateFlags flags; + /// Vulkan physical device. + /** It must be valid throughout whole lifetime of created allocator. */ + VkPhysicalDevice VMA_NOT_NULL physicalDevice; + /// Vulkan device. + /** It must be valid throughout whole lifetime of created allocator. */ + VkDevice VMA_NOT_NULL device; + /// Preferred size of a single `VkDeviceMemory` block to be allocated from large heaps > 1 GiB. Optional. + /** Set to 0 to use default, which is currently 256 MiB. */ + VkDeviceSize preferredLargeHeapBlockSize; + /// Custom CPU memory allocation callbacks. Optional. + /** Optional, can be null. When specified, will also be used for all CPU-side memory allocations. */ + const VkAllocationCallbacks* VMA_NULLABLE pAllocationCallbacks; + /// Informative callbacks for `vkAllocateMemory`, `vkFreeMemory`. Optional. + /** Optional, can be null. */ + const VmaDeviceMemoryCallbacks* VMA_NULLABLE pDeviceMemoryCallbacks; + /** \brief Maximum number of additional frames that are in use at the same time as current frame. + + This value is used only when you make allocations with + VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag. Such allocation cannot become + lost if allocation.lastUseFrameIndex >= allocator.currentFrameIndex - frameInUseCount. + + For example, if you double-buffer your command buffers, so resources used for + rendering in previous frame may still be in use by the GPU at the moment you + allocate resources needed for the current frame, set this value to 1. + + If you want to allow any allocations other than used in the current frame to + become lost, set this value to 0. + */ + uint32_t frameInUseCount; + /** \brief Either null or a pointer to an array of limits on maximum number of bytes that can be allocated out of particular Vulkan memory heap. + + If not NULL, it must be a pointer to an array of + `VkPhysicalDeviceMemoryProperties::memoryHeapCount` elements, defining limit on + maximum number of bytes that can be allocated out of particular Vulkan memory + heap. + + Any of the elements may be equal to `VK_WHOLE_SIZE`, which means no limit on that + heap. This is also the default in case of `pHeapSizeLimit` = NULL. + + If there is a limit defined for a heap: + + - If user tries to allocate more memory from that heap using this allocator, + the allocation fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY`. + - If the limit is smaller than heap size reported in `VkMemoryHeap::size`, the + value of this limit will be reported instead when using vmaGetMemoryProperties(). + + Warning! Using this feature may not be equivalent to installing a GPU with + smaller amount of memory, because graphics driver doesn't necessary fail new + allocations with `VK_ERROR_OUT_OF_DEVICE_MEMORY` result when memory capacity is + exceeded. It may return success and just silently migrate some device memory + blocks to system RAM. This driver behavior can also be controlled using + VK_AMD_memory_overallocation_behavior extension. + */ + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL("VkPhysicalDeviceMemoryProperties::memoryHeapCount") pHeapSizeLimit; + + /** \brief Pointers to Vulkan functions. Can be null. + + For details see [Pointers to Vulkan functions](@ref config_Vulkan_functions). + */ + const VmaVulkanFunctions* VMA_NULLABLE pVulkanFunctions; + /** \brief Parameters for recording of VMA calls. Can be null. + + If not null, it enables recording of calls to VMA functions to a file. + If support for recording is not enabled using `VMA_RECORDING_ENABLED` macro, + creation of the allocator object fails with `VK_ERROR_FEATURE_NOT_PRESENT`. + */ + const VmaRecordSettings* VMA_NULLABLE pRecordSettings; + /** \brief Handle to Vulkan instance object. + + Starting from version 3.0.0 this member is no longer optional, it must be set! + */ + VkInstance VMA_NOT_NULL instance; + /** \brief Optional. The highest version of Vulkan that the application is designed to use. + + It must be a value in the format as created by macro `VK_MAKE_VERSION` or a constant like: `VK_API_VERSION_1_1`, `VK_API_VERSION_1_0`. + The patch version number specified is ignored. Only the major and minor versions are considered. + It must be less or equal (preferably equal) to value as passed to `vkCreateInstance` as `VkApplicationInfo::apiVersion`. + Only versions 1.0 and 1.1 are supported by the current implementation. + Leaving it initialized to zero is equivalent to `VK_API_VERSION_1_0`. + */ + uint32_t vulkanApiVersion; +} VmaAllocatorCreateInfo; + +/// Creates Allocator object. +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAllocator( + const VmaAllocatorCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocator VMA_NULLABLE * VMA_NOT_NULL pAllocator); + +/// Destroys allocator object. +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyAllocator( + VmaAllocator VMA_NULLABLE allocator); + +/** \brief Information about existing #VmaAllocator object. +*/ +typedef struct VmaAllocatorInfo +{ + /** \brief Handle to Vulkan instance object. + + This is the same value as has been passed through VmaAllocatorCreateInfo::instance. + */ + VkInstance VMA_NOT_NULL instance; + /** \brief Handle to Vulkan physical device object. + + This is the same value as has been passed through VmaAllocatorCreateInfo::physicalDevice. + */ + VkPhysicalDevice VMA_NOT_NULL physicalDevice; + /** \brief Handle to Vulkan device object. + + This is the same value as has been passed through VmaAllocatorCreateInfo::device. + */ + VkDevice VMA_NOT_NULL device; +} VmaAllocatorInfo; + +/** \brief Returns information about existing #VmaAllocator object - handle to Vulkan device etc. + +It might be useful if you want to keep just the #VmaAllocator handle and fetch other required handles to +`VkPhysicalDevice`, `VkDevice` etc. every time using this function. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocatorInfo(VmaAllocator VMA_NOT_NULL allocator, VmaAllocatorInfo* VMA_NOT_NULL pAllocatorInfo); + +/** +PhysicalDeviceProperties are fetched from physicalDevice by the allocator. +You can access it here, without fetching it again on your own. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetPhysicalDeviceProperties( + VmaAllocator VMA_NOT_NULL allocator, + const VkPhysicalDeviceProperties* VMA_NULLABLE * VMA_NOT_NULL ppPhysicalDeviceProperties); + +/** +PhysicalDeviceMemoryProperties are fetched from physicalDevice by the allocator. +You can access it here, without fetching it again on your own. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryProperties( + VmaAllocator VMA_NOT_NULL allocator, + const VkPhysicalDeviceMemoryProperties* VMA_NULLABLE * VMA_NOT_NULL ppPhysicalDeviceMemoryProperties); + +/** +\brief Given Memory Type Index, returns Property Flags of this memory type. + +This is just a convenience function. Same information can be obtained using +vmaGetMemoryProperties(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryTypeProperties( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryTypeIndex, + VkMemoryPropertyFlags* VMA_NOT_NULL pFlags); + +/** \brief Sets index of the current frame. + +This function must be used if you make allocations with +#VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT and +#VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flags to inform the allocator +when a new frame begins. Allocations queried using vmaGetAllocationInfo() cannot +become lost in the current frame. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetCurrentFrameIndex( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t frameIndex); + +/** \brief Calculated statistics of memory usage in entire allocator. +*/ +typedef struct VmaStatInfo +{ + /// Number of `VkDeviceMemory` Vulkan memory blocks allocated. + uint32_t blockCount; + /// Number of #VmaAllocation allocation objects allocated. + uint32_t allocationCount; + /// Number of free ranges of memory between allocations. + uint32_t unusedRangeCount; + /// Total number of bytes occupied by all allocations. + VkDeviceSize usedBytes; + /// Total number of bytes occupied by unused ranges. + VkDeviceSize unusedBytes; + VkDeviceSize allocationSizeMin, allocationSizeAvg, allocationSizeMax; + VkDeviceSize unusedRangeSizeMin, unusedRangeSizeAvg, unusedRangeSizeMax; +} VmaStatInfo; + +/// General statistics from current state of Allocator. +typedef struct VmaStats +{ + VmaStatInfo memoryType[VK_MAX_MEMORY_TYPES]; + VmaStatInfo memoryHeap[VK_MAX_MEMORY_HEAPS]; + VmaStatInfo total; +} VmaStats; + +/** \brief Retrieves statistics from current state of the Allocator. + +This function is called "calculate" not "get" because it has to traverse all +internal data structures, so it may be quite slow. For faster but more brief statistics +suitable to be called every frame or every allocation, use vmaGetBudget(). + +Note that when using allocator from multiple threads, returned information may immediately +become outdated. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaCalculateStats( + VmaAllocator VMA_NOT_NULL allocator, + VmaStats* VMA_NOT_NULL pStats); + +/** \brief Statistics of current memory usage and available budget, in bytes, for specific memory heap. +*/ +typedef struct VmaBudget +{ + /** \brief Sum size of all `VkDeviceMemory` blocks allocated from particular heap, in bytes. + */ + VkDeviceSize blockBytes; + + /** \brief Sum size of all allocations created in particular heap, in bytes. + + Usually less or equal than `blockBytes`. + Difference `blockBytes - allocationBytes` is the amount of memory allocated but unused - + available for new allocations or wasted due to fragmentation. + + It might be greater than `blockBytes` if there are some allocations in lost state, as they account + to this value as well. + */ + VkDeviceSize allocationBytes; + + /** \brief Estimated current memory usage of the program, in bytes. + + Fetched from system using `VK_EXT_memory_budget` extension if enabled. + + It might be different than `blockBytes` (usually higher) due to additional implicit objects + also occupying the memory, like swapchain, pipelines, descriptor heaps, command buffers, or + `VkDeviceMemory` blocks allocated outside of this library, if any. + */ + VkDeviceSize usage; + + /** \brief Estimated amount of memory available to the program, in bytes. + + Fetched from system using `VK_EXT_memory_budget` extension if enabled. + + It might be different (most probably smaller) than `VkMemoryHeap::size[heapIndex]` due to factors + external to the program, like other programs also consuming system resources. + Difference `budget - usage` is the amount of additional memory that can probably + be allocated without problems. Exceeding the budget may result in various problems. + */ + VkDeviceSize budget; +} VmaBudget; + +/** \brief Retrieves information about current memory budget for all memory heaps. + +\param[out] pBudget Must point to array with number of elements at least equal to number of memory heaps in physical device used. + +This function is called "get" not "calculate" because it is very fast, suitable to be called +every frame or every allocation. For more detailed statistics use vmaCalculateStats(). + +Note that when using allocator from multiple threads, returned information may immediately +become outdated. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetBudget( + VmaAllocator VMA_NOT_NULL allocator, + VmaBudget* VMA_NOT_NULL pBudget); + +#ifndef VMA_STATS_STRING_ENABLED +#define VMA_STATS_STRING_ENABLED 1 +#endif + +#if VMA_STATS_STRING_ENABLED + +/// Builds and returns statistics as string in JSON format. +/** @param[out] ppStatsString Must be freed using vmaFreeStatsString() function. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaBuildStatsString( + VmaAllocator VMA_NOT_NULL allocator, + char* VMA_NULLABLE * VMA_NOT_NULL ppStatsString, + VkBool32 detailedMap); + +VMA_CALL_PRE void VMA_CALL_POST vmaFreeStatsString( + VmaAllocator VMA_NOT_NULL allocator, + char* VMA_NULLABLE pStatsString); + +#endif // #if VMA_STATS_STRING_ENABLED + +/** \struct VmaPool +\brief Represents custom memory pool + +Fill structure VmaPoolCreateInfo and call function vmaCreatePool() to create it. +Call function vmaDestroyPool() to destroy it. + +For more information see [Custom memory pools](@ref choosing_memory_type_custom_memory_pools). +*/ +VK_DEFINE_HANDLE(VmaPool) + +typedef enum VmaMemoryUsage +{ + /** No intended memory usage specified. + Use other members of VmaAllocationCreateInfo to specify your requirements. + */ + VMA_MEMORY_USAGE_UNKNOWN = 0, + /** Memory will be used on device only, so fast access from the device is preferred. + It usually means device-local GPU (video) memory. + No need to be mappable on host. + It is roughly equivalent of `D3D12_HEAP_TYPE_DEFAULT`. + + Usage: + + - Resources written and read by device, e.g. images used as attachments. + - Resources transferred from host once (immutable) or infrequently and read by + device multiple times, e.g. textures to be sampled, vertex buffers, uniform + (constant) buffers, and majority of other types of resources used on GPU. + + Allocation may still end up in `HOST_VISIBLE` memory on some implementations. + In such case, you are free to map it. + You can use #VMA_ALLOCATION_CREATE_MAPPED_BIT with this usage type. + */ + VMA_MEMORY_USAGE_GPU_ONLY = 1, + /** Memory will be mappable on host. + It usually means CPU (system) memory. + Guarantees to be `HOST_VISIBLE` and `HOST_COHERENT`. + CPU access is typically uncached. Writes may be write-combined. + Resources created in this pool may still be accessible to the device, but access to them can be slow. + It is roughly equivalent of `D3D12_HEAP_TYPE_UPLOAD`. + + Usage: Staging copy of resources used as transfer source. + */ + VMA_MEMORY_USAGE_CPU_ONLY = 2, + /** + Memory that is both mappable on host (guarantees to be `HOST_VISIBLE`) and preferably fast to access by GPU. + CPU access is typically uncached. Writes may be write-combined. + + Usage: Resources written frequently by host (dynamic), read by device. E.g. textures (with LINEAR layout), vertex buffers, uniform buffers updated every frame or every draw call. + */ + VMA_MEMORY_USAGE_CPU_TO_GPU = 3, + /** Memory mappable on host (guarantees to be `HOST_VISIBLE`) and cached. + It is roughly equivalent of `D3D12_HEAP_TYPE_READBACK`. + + Usage: + + - Resources written by device, read by host - results of some computations, e.g. screen capture, average scene luminance for HDR tone mapping. + - Any resources read or accessed randomly on host, e.g. CPU-side copy of vertex buffer used as source of transfer, but also used for collision detection. + */ + VMA_MEMORY_USAGE_GPU_TO_CPU = 4, + /** CPU memory - memory that is preferably not `DEVICE_LOCAL`, but also not guaranteed to be `HOST_VISIBLE`. + + Usage: Staging copy of resources moved from GPU memory to CPU memory as part + of custom paging/residency mechanism, to be moved back to GPU memory when needed. + */ + VMA_MEMORY_USAGE_CPU_COPY = 5, + /** Lazily allocated GPU memory having `VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT`. + Exists mostly on mobile platforms. Using it on desktop PC or other GPUs with no such memory type present will fail the allocation. + + Usage: Memory for transient attachment images (color attachments, depth attachments etc.), created with `VK_IMAGE_USAGE_TRANSIENT_ATTACHMENT_BIT`. + + Allocations with this usage are always created as dedicated - it implies #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. + */ + VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED = 6, + + VMA_MEMORY_USAGE_MAX_ENUM = 0x7FFFFFFF +} VmaMemoryUsage; + +/// Flags to be passed as VmaAllocationCreateInfo::flags. +typedef enum VmaAllocationCreateFlagBits { + /** \brief Set this flag if the allocation should have its own memory block. + + Use it for special, big resources, like fullscreen images used as attachments. + + You should not use this flag if VmaAllocationCreateInfo::pool is not null. + */ + VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT = 0x00000001, + + /** \brief Set this flag to only try to allocate from existing `VkDeviceMemory` blocks and never create new such block. + + If new allocation cannot be placed in any of the existing blocks, allocation + fails with `VK_ERROR_OUT_OF_DEVICE_MEMORY` error. + + You should not use #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT and + #VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT at the same time. It makes no sense. + + If VmaAllocationCreateInfo::pool is not null, this flag is implied and ignored. */ + VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT = 0x00000002, + /** \brief Set this flag to use a memory that will be persistently mapped and retrieve pointer to it. + + Pointer to mapped memory will be returned through VmaAllocationInfo::pMappedData. + + Is it valid to use this flag for allocation made from memory type that is not + `HOST_VISIBLE`. This flag is then ignored and memory is not mapped. This is + useful if you need an allocation that is efficient to use on GPU + (`DEVICE_LOCAL`) and still want to map it directly if possible on platforms that + support it (e.g. Intel GPU). + + You should not use this flag together with #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT. + */ + VMA_ALLOCATION_CREATE_MAPPED_BIT = 0x00000004, + /** Allocation created with this flag can become lost as a result of another + allocation with #VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT flag, so you + must check it before use. + + To check if allocation is not lost, call vmaGetAllocationInfo() and check if + VmaAllocationInfo::deviceMemory is not `VK_NULL_HANDLE`. + + For details about supporting lost allocations, see Lost Allocations + chapter of User Guide on Main Page. + + You should not use this flag together with #VMA_ALLOCATION_CREATE_MAPPED_BIT. + */ + VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT = 0x00000008, + /** While creating allocation using this flag, other allocations that were + created with flag #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT can become lost. + + For details about supporting lost allocations, see Lost Allocations + chapter of User Guide on Main Page. + */ + VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT = 0x00000010, + /** Set this flag to treat VmaAllocationCreateInfo::pUserData as pointer to a + null-terminated string. Instead of copying pointer value, a local copy of the + string is made and stored in allocation's `pUserData`. The string is automatically + freed together with the allocation. It is also used in vmaBuildStatsString(). + */ + VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT = 0x00000020, + /** Allocation will be created from upper stack in a double stack pool. + + This flag is only allowed for custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT flag. + */ + VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT = 0x00000040, + /** Create both buffer/image and allocation, but don't bind them together. + It is useful when you want to bind yourself to do some more advanced binding, e.g. using some extensions. + The flag is meaningful only with functions that bind by default: vmaCreateBuffer(), vmaCreateImage(). + Otherwise it is ignored. + */ + VMA_ALLOCATION_CREATE_DONT_BIND_BIT = 0x00000080, + /** Create allocation only if additional device memory required for it, if any, won't exceed + memory budget. Otherwise return `VK_ERROR_OUT_OF_DEVICE_MEMORY`. + */ + VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT = 0x00000100, + + /** Allocation strategy that chooses smallest possible free range for the + allocation. + */ + VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT = 0x00010000, + /** Allocation strategy that chooses biggest possible free range for the + allocation. + */ + VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT = 0x00020000, + /** Allocation strategy that chooses first suitable free range for the + allocation. + + "First" doesn't necessarily means the one with smallest offset in memory, + but rather the one that is easiest and fastest to find. + */ + VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT = 0x00040000, + + /** Allocation strategy that tries to minimize memory usage. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT = VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT, + /** Allocation strategy that tries to minimize allocation time. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT = VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT, + /** Allocation strategy that tries to minimize memory fragmentation. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT = VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT, + + /** A bit mask to extract only `STRATEGY` bits from entire set of flags. + */ + VMA_ALLOCATION_CREATE_STRATEGY_MASK = + VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT | + VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT | + VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT, + + VMA_ALLOCATION_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaAllocationCreateFlagBits; +typedef VkFlags VmaAllocationCreateFlags; + +typedef struct VmaAllocationCreateInfo +{ + /// Use #VmaAllocationCreateFlagBits enum. + VmaAllocationCreateFlags flags; + /** \brief Intended usage of memory. + + You can leave #VMA_MEMORY_USAGE_UNKNOWN if you specify memory requirements in other way. \n + If `pool` is not null, this member is ignored. + */ + VmaMemoryUsage usage; + /** \brief Flags that must be set in a Memory Type chosen for an allocation. + + Leave 0 if you specify memory requirements in other way. \n + If `pool` is not null, this member is ignored.*/ + VkMemoryPropertyFlags requiredFlags; + /** \brief Flags that preferably should be set in a memory type chosen for an allocation. + + Set to 0 if no additional flags are prefered. \n + If `pool` is not null, this member is ignored. */ + VkMemoryPropertyFlags preferredFlags; + /** \brief Bitmask containing one bit set for every memory type acceptable for this allocation. + + Value 0 is equivalent to `UINT32_MAX` - it means any memory type is accepted if + it meets other requirements specified by this structure, with no further + restrictions on memory type index. \n + If `pool` is not null, this member is ignored. + */ + uint32_t memoryTypeBits; + /** \brief Pool that this allocation should be created in. + + Leave `VK_NULL_HANDLE` to allocate from default pool. If not null, members: + `usage`, `requiredFlags`, `preferredFlags`, `memoryTypeBits` are ignored. + */ + VmaPool VMA_NULLABLE pool; + /** \brief Custom general-purpose pointer that will be stored in #VmaAllocation, can be read as VmaAllocationInfo::pUserData and changed using vmaSetAllocationUserData(). + + If #VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT is used, it must be either + null or pointer to a null-terminated string. The string will be then copied to + internal buffer, so it doesn't need to be valid after allocation call. + */ + void* VMA_NULLABLE pUserData; +} VmaAllocationCreateInfo; + +/** +\brief Helps to find memoryTypeIndex, given memoryTypeBits and VmaAllocationCreateInfo. + +This algorithm tries to find a memory type that: + +- Is allowed by memoryTypeBits. +- Contains all the flags from pAllocationCreateInfo->requiredFlags. +- Matches intended usage. +- Has as many flags from pAllocationCreateInfo->preferredFlags as possible. + +\return Returns VK_ERROR_FEATURE_NOT_PRESENT if not found. Receiving such result +from this function or any other allocating function probably means that your +device doesn't support any memory type with requested features for the specific +type of resource you want to use it for. Please check parameters of your +resource, like image layout (OPTIMAL versus LINEAR) or mip level count. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndex( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t memoryTypeBits, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + uint32_t* VMA_NOT_NULL pMemoryTypeIndex); + +/** +\brief Helps to find memoryTypeIndex, given VkBufferCreateInfo and VmaAllocationCreateInfo. + +It can be useful e.g. to determine value to be used as VmaPoolCreateInfo::memoryTypeIndex. +It internally creates a temporary, dummy buffer that never has memory bound. +It is just a convenience function, equivalent to calling: + +- `vkCreateBuffer` +- `vkGetBufferMemoryRequirements` +- `vmaFindMemoryTypeIndex` +- `vkDestroyBuffer` +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForBufferInfo( + VmaAllocator VMA_NOT_NULL allocator, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + uint32_t* VMA_NOT_NULL pMemoryTypeIndex); + +/** +\brief Helps to find memoryTypeIndex, given VkImageCreateInfo and VmaAllocationCreateInfo. + +It can be useful e.g. to determine value to be used as VmaPoolCreateInfo::memoryTypeIndex. +It internally creates a temporary, dummy image that never has memory bound. +It is just a convenience function, equivalent to calling: + +- `vkCreateImage` +- `vkGetImageMemoryRequirements` +- `vmaFindMemoryTypeIndex` +- `vkDestroyImage` +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForImageInfo( + VmaAllocator VMA_NOT_NULL allocator, + const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + uint32_t* VMA_NOT_NULL pMemoryTypeIndex); + +/// Flags to be passed as VmaPoolCreateInfo::flags. +typedef enum VmaPoolCreateFlagBits { + /** \brief Use this flag if you always allocate only buffers and linear images or only optimal images out of this pool and so Buffer-Image Granularity can be ignored. + + This is an optional optimization flag. + + If you always allocate using vmaCreateBuffer(), vmaCreateImage(), + vmaAllocateMemoryForBuffer(), then you don't need to use it because allocator + knows exact type of your allocations so it can handle Buffer-Image Granularity + in the optimal way. + + If you also allocate using vmaAllocateMemoryForImage() or vmaAllocateMemory(), + exact type of such allocations is not known, so allocator must be conservative + in handling Buffer-Image Granularity, which can lead to suboptimal allocation + (wasted memory). In that case, if you can make sure you always allocate only + buffers and linear images or only optimal images out of this pool, use this flag + to make allocator disregard Buffer-Image Granularity and so make allocations + faster and more optimal. + */ + VMA_POOL_CREATE_IGNORE_BUFFER_IMAGE_GRANULARITY_BIT = 0x00000002, + + /** \brief Enables alternative, linear allocation algorithm in this pool. + + Specify this flag to enable linear allocation algorithm, which always creates + new allocations after last one and doesn't reuse space from allocations freed in + between. It trades memory consumption for simplified algorithm and data + structure, which has better performance and uses less memory for metadata. + + By using this flag, you can achieve behavior of free-at-once, stack, + ring buffer, and double stack. For details, see documentation chapter + \ref linear_algorithm. + + When using this flag, you must specify VmaPoolCreateInfo::maxBlockCount == 1 (or 0 for default). + + For more details, see [Linear allocation algorithm](@ref linear_algorithm). + */ + VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT = 0x00000004, + + /** \brief Enables alternative, buddy allocation algorithm in this pool. + + It operates on a tree of blocks, each having size that is a power of two and + a half of its parent's size. Comparing to default algorithm, this one provides + faster allocation and deallocation and decreased external fragmentation, + at the expense of more memory wasted (internal fragmentation). + + For more details, see [Buddy allocation algorithm](@ref buddy_algorithm). + */ + VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT = 0x00000008, + + /** Bit mask to extract only `ALGORITHM` bits from entire set of flags. + */ + VMA_POOL_CREATE_ALGORITHM_MASK = + VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT | + VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT, + + VMA_POOL_CREATE_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaPoolCreateFlagBits; +typedef VkFlags VmaPoolCreateFlags; + +/** \brief Describes parameter of created #VmaPool. +*/ +typedef struct VmaPoolCreateInfo { + /** \brief Vulkan memory type index to allocate this pool from. + */ + uint32_t memoryTypeIndex; + /** \brief Use combination of #VmaPoolCreateFlagBits. + */ + VmaPoolCreateFlags flags; + /** \brief Size of a single `VkDeviceMemory` block to be allocated as part of this pool, in bytes. Optional. + + Specify nonzero to set explicit, constant size of memory blocks used by this + pool. + + Leave 0 to use default and let the library manage block sizes automatically. + Sizes of particular blocks may vary. + */ + VkDeviceSize blockSize; + /** \brief Minimum number of blocks to be always allocated in this pool, even if they stay empty. + + Set to 0 to have no preallocated blocks and allow the pool be completely empty. + */ + size_t minBlockCount; + /** \brief Maximum number of blocks that can be allocated in this pool. Optional. + + Set to 0 to use default, which is `SIZE_MAX`, which means no limit. + + Set to same value as VmaPoolCreateInfo::minBlockCount to have fixed amount of memory allocated + throughout whole lifetime of this pool. + */ + size_t maxBlockCount; + /** \brief Maximum number of additional frames that are in use at the same time as current frame. + + This value is used only when you make allocations with + #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag. Such allocation cannot become + lost if allocation.lastUseFrameIndex >= allocator.currentFrameIndex - frameInUseCount. + + For example, if you double-buffer your command buffers, so resources used for + rendering in previous frame may still be in use by the GPU at the moment you + allocate resources needed for the current frame, set this value to 1. + + If you want to allow any allocations other than used in the current frame to + become lost, set this value to 0. + */ + uint32_t frameInUseCount; +} VmaPoolCreateInfo; + +/** \brief Describes parameter of existing #VmaPool. +*/ +typedef struct VmaPoolStats { + /** \brief Total amount of `VkDeviceMemory` allocated from Vulkan for this pool, in bytes. + */ + VkDeviceSize size; + /** \brief Total number of bytes in the pool not used by any #VmaAllocation. + */ + VkDeviceSize unusedSize; + /** \brief Number of #VmaAllocation objects created from this pool that were not destroyed or lost. + */ + size_t allocationCount; + /** \brief Number of continuous memory ranges in the pool not used by any #VmaAllocation. + */ + size_t unusedRangeCount; + /** \brief Size of the largest continuous free memory region available for new allocation. + + Making a new allocation of that size is not guaranteed to succeed because of + possible additional margin required to respect alignment and buffer/image + granularity. + */ + VkDeviceSize unusedRangeSizeMax; + /** \brief Number of `VkDeviceMemory` blocks allocated for this pool. + */ + size_t blockCount; +} VmaPoolStats; + +/** \brief Allocates Vulkan device memory and creates #VmaPool object. + +@param allocator Allocator object. +@param pCreateInfo Parameters of pool to create. +@param[out] pPool Handle to created pool. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreatePool( + VmaAllocator VMA_NOT_NULL allocator, + const VmaPoolCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaPool VMA_NULLABLE * VMA_NOT_NULL pPool); + +/** \brief Destroys #VmaPool object and frees Vulkan device memory. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyPool( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NULLABLE pool); + +/** \brief Retrieves statistics of existing #VmaPool object. + +@param allocator Allocator object. +@param pool Pool object. +@param[out] pPoolStats Statistics of specified pool. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolStats( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + VmaPoolStats* VMA_NOT_NULL pPoolStats); + +/** \brief Marks all allocations in given pool as lost if they are not used in current frame or VmaPoolCreateInfo::frameInUseCount back from now. + +@param allocator Allocator object. +@param pool Pool. +@param[out] pLostAllocationCount Number of allocations marked as lost. Optional - pass null if you don't need this information. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaMakePoolAllocationsLost( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + size_t* VMA_NULLABLE pLostAllocationCount); + +/** \brief Checks magic number in margins around all allocations in given memory pool in search for corruptions. + +Corruption detection is enabled only when `VMA_DEBUG_DETECT_CORRUPTION` macro is defined to nonzero, +`VMA_DEBUG_MARGIN` is defined to nonzero and the pool is created in memory type that is +`HOST_VISIBLE` and `HOST_COHERENT`. For more information, see [Corruption detection](@ref debugging_memory_usage_corruption_detection). + +Possible return values: + +- `VK_ERROR_FEATURE_NOT_PRESENT` - corruption detection is not enabled for specified pool. +- `VK_SUCCESS` - corruption detection has been performed and succeeded. +- `VK_ERROR_VALIDATION_FAILED_EXT` - corruption detection has been performed and found memory corruptions around one of the allocations. + `VMA_ASSERT` is also fired in that case. +- Other value: Error returned by Vulkan, e.g. memory mapping failure. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckPoolCorruption(VmaAllocator VMA_NOT_NULL allocator, VmaPool VMA_NOT_NULL pool); + +/** \brief Retrieves name of a custom pool. + +After the call `ppName` is either null or points to an internally-owned null-terminated string +containing name of the pool that was previously set. The pointer becomes invalid when the pool is +destroyed or its name is changed using vmaSetPoolName(). +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolName( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + const char* VMA_NULLABLE * VMA_NOT_NULL ppName); + +/** \brief Sets name of a custom pool. + +`pName` can be either null or pointer to a null-terminated string with new name for the pool. +Function makes internal copy of the string, so it can be changed or freed immediately after this call. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetPoolName( + VmaAllocator VMA_NOT_NULL allocator, + VmaPool VMA_NOT_NULL pool, + const char* VMA_NULLABLE pName); + +/** \struct VmaAllocation +\brief Represents single memory allocation. + +It may be either dedicated block of `VkDeviceMemory` or a specific region of a bigger block of this type +plus unique offset. + +There are multiple ways to create such object. +You need to fill structure VmaAllocationCreateInfo. +For more information see [Choosing memory type](@ref choosing_memory_type). + +Although the library provides convenience functions that create Vulkan buffer or image, +allocate memory for it and bind them together, +binding of the allocation to a buffer or an image is out of scope of the allocation itself. +Allocation object can exist without buffer/image bound, +binding can be done manually by the user, and destruction of it can be done +independently of destruction of the allocation. + +The object also remembers its size and some other information. +To retrieve this information, use function vmaGetAllocationInfo() and inspect +returned structure VmaAllocationInfo. + +Some kinds allocations can be in lost state. +For more information, see [Lost allocations](@ref lost_allocations). +*/ +VK_DEFINE_HANDLE(VmaAllocation) + +/** \brief Parameters of #VmaAllocation objects, that can be retrieved using function vmaGetAllocationInfo(). +*/ +typedef struct VmaAllocationInfo { + /** \brief Memory type index that this allocation was allocated from. + + It never changes. + */ + uint32_t memoryType; + /** \brief Handle to Vulkan memory object. + + Same memory object can be shared by multiple allocations. + + It can change after call to vmaDefragment() if this allocation is passed to the function, or if allocation is lost. + + If the allocation is lost, it is equal to `VK_NULL_HANDLE`. + */ + VkDeviceMemory VMA_NULLABLE_NON_DISPATCHABLE deviceMemory; + /** \brief Offset into deviceMemory object to the beginning of this allocation, in bytes. (deviceMemory, offset) pair is unique to this allocation. + + It can change after call to vmaDefragment() if this allocation is passed to the function, or if allocation is lost. + */ + VkDeviceSize offset; + /** \brief Size of this allocation, in bytes. + + It never changes, unless allocation is lost. + + \note Allocation size returned in this variable may be greater than the size + requested for the resource e.g. as `VkBufferCreateInfo::size`. Whole size of the + allocation is accessible for operations on memory e.g. using a pointer after + mapping with vmaMapMemory(), but operations on the resource e.g. using + `vkCmdCopyBuffer` must be limited to the size of the resource. + */ + VkDeviceSize size; + /** \brief Pointer to the beginning of this allocation as mapped data. + + If the allocation hasn't been mapped using vmaMapMemory() and hasn't been + created with #VMA_ALLOCATION_CREATE_MAPPED_BIT flag, this value is null. + + It can change after call to vmaMapMemory(), vmaUnmapMemory(). + It can also change after call to vmaDefragment() if this allocation is passed to the function. + */ + void* VMA_NULLABLE pMappedData; + /** \brief Custom general-purpose pointer that was passed as VmaAllocationCreateInfo::pUserData or set using vmaSetAllocationUserData(). + + It can change after call to vmaSetAllocationUserData() for this allocation. + */ + void* VMA_NULLABLE pUserData; +} VmaAllocationInfo; + +/** \brief General purpose memory allocation. + +@param[out] pAllocation Handle to allocated memory. +@param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). + +You should free the memory using vmaFreeMemory() or vmaFreeMemoryPages(). + +It is recommended to use vmaAllocateMemoryForBuffer(), vmaAllocateMemoryForImage(), +vmaCreateBuffer(), vmaCreateImage() instead whenever possible. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemory( + VmaAllocator VMA_NOT_NULL allocator, + const VkMemoryRequirements* VMA_NOT_NULL pVkMemoryRequirements, + const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); + +/** \brief General purpose memory allocation for multiple allocation objects at once. + +@param allocator Allocator object. +@param pVkMemoryRequirements Memory requirements for each allocation. +@param pCreateInfo Creation parameters for each alloction. +@param allocationCount Number of allocations to make. +@param[out] pAllocations Pointer to array that will be filled with handles to created allocations. +@param[out] pAllocationInfo Optional. Pointer to array that will be filled with parameters of created allocations. + +You should free the memory using vmaFreeMemory() or vmaFreeMemoryPages(). + +Word "pages" is just a suggestion to use this function to allocate pieces of memory needed for sparse binding. +It is just a general purpose allocation function able to make multiple allocations at once. +It may be internally optimized to be more efficient than calling vmaAllocateMemory() `allocationCount` times. + +All allocations are made using same parameters. All of them are created out of the same memory pool and type. +If any allocation fails, all allocations already made within this function call are also freed, so that when +returned result is not `VK_SUCCESS`, `pAllocation` array is always entirely filled with `VK_NULL_HANDLE`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryPages( + VmaAllocator VMA_NOT_NULL allocator, + const VkMemoryRequirements* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pVkMemoryRequirements, + const VmaAllocationCreateInfo* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pCreateInfo, + size_t allocationCount, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations, + VmaAllocationInfo* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationInfo); + +/** +@param[out] pAllocation Handle to allocated memory. +@param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). + +You should free the memory using vmaFreeMemory(). +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForBuffer( + VmaAllocator VMA_NOT_NULL allocator, + VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer, + const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); + +/// Function similar to vmaAllocateMemoryForBuffer(). +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForImage( + VmaAllocator VMA_NOT_NULL allocator, + VkImage VMA_NOT_NULL_NON_DISPATCHABLE image, + const VmaAllocationCreateInfo* VMA_NOT_NULL pCreateInfo, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); + +/** \brief Frees memory previously allocated using vmaAllocateMemory(), vmaAllocateMemoryForBuffer(), or vmaAllocateMemoryForImage(). + +Passing `VK_NULL_HANDLE` as `allocation` is valid. Such function call is just skipped. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemory( + VmaAllocator VMA_NOT_NULL allocator, + const VmaAllocation VMA_NULLABLE allocation); + +/** \brief Frees memory and destroys multiple allocations. + +Word "pages" is just a suggestion to use this function to free pieces of memory used for sparse binding. +It is just a general purpose function to free memory and destroy allocations made using e.g. vmaAllocateMemory(), +vmaAllocateMemoryPages() and other functions. +It may be internally optimized to be more efficient than calling vmaFreeMemory() `allocationCount` times. + +Allocations in `pAllocations` array can come from any memory pools and types. +Passing `VK_NULL_HANDLE` as elements of `pAllocations` array is valid. Such entries are just skipped. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemoryPages( + VmaAllocator VMA_NOT_NULL allocator, + size_t allocationCount, + const VmaAllocation VMA_NULLABLE * VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations); + +/** \brief Deprecated. + +\deprecated +In version 2.2.0 it used to try to change allocation's size without moving or reallocating it. +In current version it returns `VK_SUCCESS` only if `newSize` equals current allocation's size. +Otherwise returns `VK_ERROR_OUT_OF_POOL_MEMORY`, indicating that allocation's size could not be changed. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaResizeAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize newSize); + +/** \brief Returns current information about specified allocation and atomically marks it as used in current frame. + +Current paramters of given allocation are returned in `pAllocationInfo`. + +This function also atomically "touches" allocation - marks it as used in current frame, +just like vmaTouchAllocation(). +If the allocation is in lost state, `pAllocationInfo->deviceMemory == VK_NULL_HANDLE`. + +Although this function uses atomics and doesn't lock any mutex, so it should be quite efficient, +you can avoid calling it too often. + +- You can retrieve same VmaAllocationInfo structure while creating your resource, from function + vmaCreateBuffer(), vmaCreateImage(). You can remember it if you are sure parameters don't change + (e.g. due to defragmentation or allocation becoming lost). +- If you just want to check if allocation is not lost, vmaTouchAllocation() will work faster. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationInfo( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VmaAllocationInfo* VMA_NOT_NULL pAllocationInfo); + +/** \brief Returns `VK_TRUE` if allocation is not lost and atomically marks it as used in current frame. + +If the allocation has been created with #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag, +this function returns `VK_TRUE` if it's not in lost state, so it can still be used. +It then also atomically "touches" the allocation - marks it as used in current frame, +so that you can be sure it won't become lost in current frame or next `frameInUseCount` frames. + +If the allocation is in lost state, the function returns `VK_FALSE`. +Memory of such allocation, as well as buffer or image bound to it, should not be used. +Lost allocation and the buffer/image still need to be destroyed. + +If the allocation has been created without #VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag, +this function always returns `VK_TRUE`. +*/ +VMA_CALL_PRE VkBool32 VMA_CALL_POST vmaTouchAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation); + +/** \brief Sets pUserData in given allocation to new value. + +If the allocation was created with VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT, +pUserData must be either null, or pointer to a null-terminated string. The function +makes local copy of the string and sets it as allocation's `pUserData`. String +passed as pUserData doesn't need to be valid for whole lifetime of the allocation - +you can free it after this call. String previously pointed by allocation's +pUserData is freed from memory. + +If the flag was not used, the value of pointer `pUserData` is just copied to +allocation's `pUserData`. It is opaque, so you can use it however you want - e.g. +as a pointer, ordinal number or some handle to you own data. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationUserData( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + void* VMA_NULLABLE pUserData); + +/** \brief Creates new allocation that is in lost state from the beginning. + +It can be useful if you need a dummy, non-null allocation. + +You still need to destroy created object using vmaFreeMemory(). + +Returned allocation is not tied to any specific memory pool or memory type and +not bound to any image or buffer. It has size = 0. It cannot be turned into +a real, non-empty allocation. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaCreateLostAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation); + +/** \brief Maps memory represented by given allocation and returns pointer to it. + +Maps memory represented by given allocation to make it accessible to CPU code. +When succeeded, `*ppData` contains pointer to first byte of this memory. +If the allocation is part of bigger `VkDeviceMemory` block, the pointer is +correctly offseted to the beginning of region assigned to this particular +allocation. + +Mapping is internally reference-counted and synchronized, so despite raw Vulkan +function `vkMapMemory()` cannot be used to map same block of `VkDeviceMemory` +multiple times simultaneously, it is safe to call this function on allocations +assigned to the same memory block. Actual Vulkan memory will be mapped on first +mapping and unmapped on last unmapping. + +If the function succeeded, you must call vmaUnmapMemory() to unmap the +allocation when mapping is no longer needed or before freeing the allocation, at +the latest. + +It also safe to call this function multiple times on the same allocation. You +must call vmaUnmapMemory() same number of times as you called vmaMapMemory(). + +It is also safe to call this function on allocation created with +#VMA_ALLOCATION_CREATE_MAPPED_BIT flag. Its memory stays mapped all the time. +You must still call vmaUnmapMemory() same number of times as you called +vmaMapMemory(). You must not call vmaUnmapMemory() additional time to free the +"0-th" mapping made automatically due to #VMA_ALLOCATION_CREATE_MAPPED_BIT flag. + +This function fails when used on allocation made in memory type that is not +`HOST_VISIBLE`. + +This function always fails when called for allocation that was created with +#VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT flag. Such allocations cannot be +mapped. + +This function doesn't automatically flush or invalidate caches. +If the allocation is made from a memory types that is not `HOST_COHERENT`, +you also need to use vmaInvalidateAllocation() / vmaFlushAllocation(), as required by Vulkan specification. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaMapMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + void* VMA_NULLABLE * VMA_NOT_NULL ppData); + +/** \brief Unmaps memory represented by given allocation, mapped previously using vmaMapMemory(). + +For details, see description of vmaMapMemory(). + +This function doesn't automatically flush or invalidate caches. +If the allocation is made from a memory types that is not `HOST_COHERENT`, +you also need to use vmaInvalidateAllocation() / vmaFlushAllocation(), as required by Vulkan specification. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaUnmapMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation); + +/** \brief Flushes memory of given allocation. + +Calls `vkFlushMappedMemoryRanges()` for memory associated with given range of given allocation. +It needs to be called after writing to a mapped memory for memory types that are not `HOST_COHERENT`. +Unmap operation doesn't do that automatically. + +- `offset` must be relative to the beginning of allocation. +- `size` can be `VK_WHOLE_SIZE`. It means all memory from `offset` the the end of given allocation. +- `offset` and `size` don't have to be aligned. + They are internally rounded down/up to multiply of `nonCoherentAtomSize`. +- If `size` is 0, this call is ignored. +- If memory type that the `allocation` belongs to is not `HOST_VISIBLE` or it is `HOST_COHERENT`, + this call is ignored. + +Warning! `offset` and `size` are relative to the contents of given `allocation`. +If you mean whole allocation, you can pass 0 and `VK_WHOLE_SIZE`, respectively. +Do not pass allocation's offset as `offset`!!! + +This function returns the `VkResult` from `vkFlushMappedMemoryRanges` if it is +called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize offset, + VkDeviceSize size); + +/** \brief Invalidates memory of given allocation. + +Calls `vkInvalidateMappedMemoryRanges()` for memory associated with given range of given allocation. +It needs to be called before reading from a mapped memory for memory types that are not `HOST_COHERENT`. +Map operation doesn't do that automatically. + +- `offset` must be relative to the beginning of allocation. +- `size` can be `VK_WHOLE_SIZE`. It means all memory from `offset` the the end of given allocation. +- `offset` and `size` don't have to be aligned. + They are internally rounded down/up to multiply of `nonCoherentAtomSize`. +- If `size` is 0, this call is ignored. +- If memory type that the `allocation` belongs to is not `HOST_VISIBLE` or it is `HOST_COHERENT`, + this call is ignored. + +Warning! `offset` and `size` are relative to the contents of given `allocation`. +If you mean whole allocation, you can pass 0 and `VK_WHOLE_SIZE`, respectively. +Do not pass allocation's offset as `offset`!!! + +This function returns the `VkResult` from `vkInvalidateMappedMemoryRanges` if +it is called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocation( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize offset, + VkDeviceSize size); + +/** \brief Flushes memory of given set of allocations. + +Calls `vkFlushMappedMemoryRanges()` for memory associated with given ranges of given allocations. +For more information, see documentation of vmaFlushAllocation(). + +\param allocator +\param allocationCount +\param allocations +\param offsets If not null, it must point to an array of offsets of regions to flush, relative to the beginning of respective allocations. Null means all ofsets are zero. +\param sizes If not null, it must point to an array of sizes of regions to flush in respective allocations. Null means `VK_WHOLE_SIZE` for all allocations. + +This function returns the `VkResult` from `vkFlushMappedMemoryRanges` if it is +called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocations( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t allocationCount, + const VmaAllocation VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) allocations, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) offsets, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) sizes); + +/** \brief Invalidates memory of given set of allocations. + +Calls `vkInvalidateMappedMemoryRanges()` for memory associated with given ranges of given allocations. +For more information, see documentation of vmaInvalidateAllocation(). + +\param allocator +\param allocationCount +\param allocations +\param offsets If not null, it must point to an array of offsets of regions to flush, relative to the beginning of respective allocations. Null means all ofsets are zero. +\param sizes If not null, it must point to an array of sizes of regions to flush in respective allocations. Null means `VK_WHOLE_SIZE` for all allocations. + +This function returns the `VkResult` from `vkInvalidateMappedMemoryRanges` if it is +called, otherwise `VK_SUCCESS`. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocations( + VmaAllocator VMA_NOT_NULL allocator, + uint32_t allocationCount, + const VmaAllocation VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) allocations, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) offsets, + const VkDeviceSize* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) sizes); + +/** \brief Checks magic number in margins around all allocations in given memory types (in both default and custom pools) in search for corruptions. + +@param memoryTypeBits Bit mask, where each bit set means that a memory type with that index should be checked. + +Corruption detection is enabled only when `VMA_DEBUG_DETECT_CORRUPTION` macro is defined to nonzero, +`VMA_DEBUG_MARGIN` is defined to nonzero and only for memory types that are +`HOST_VISIBLE` and `HOST_COHERENT`. For more information, see [Corruption detection](@ref debugging_memory_usage_corruption_detection). + +Possible return values: + +- `VK_ERROR_FEATURE_NOT_PRESENT` - corruption detection is not enabled for any of specified memory types. +- `VK_SUCCESS` - corruption detection has been performed and succeeded. +- `VK_ERROR_VALIDATION_FAILED_EXT` - corruption detection has been performed and found memory corruptions around one of the allocations. + `VMA_ASSERT` is also fired in that case. +- Other value: Error returned by Vulkan, e.g. memory mapping failure. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckCorruption(VmaAllocator VMA_NOT_NULL allocator, uint32_t memoryTypeBits); + +/** \struct VmaDefragmentationContext +\brief Represents Opaque object that represents started defragmentation process. + +Fill structure #VmaDefragmentationInfo2 and call function vmaDefragmentationBegin() to create it. +Call function vmaDefragmentationEnd() to destroy it. +*/ +VK_DEFINE_HANDLE(VmaDefragmentationContext) + +/// Flags to be used in vmaDefragmentationBegin(). None at the moment. Reserved for future use. +typedef enum VmaDefragmentationFlagBits { + VMA_DEFRAGMENTATION_FLAG_INCREMENTAL = 0x1, + VMA_DEFRAGMENTATION_FLAG_BITS_MAX_ENUM = 0x7FFFFFFF +} VmaDefragmentationFlagBits; +typedef VkFlags VmaDefragmentationFlags; + +/** \brief Parameters for defragmentation. + +To be used with function vmaDefragmentationBegin(). +*/ +typedef struct VmaDefragmentationInfo2 { + /** \brief Reserved for future use. Should be 0. + */ + VmaDefragmentationFlags flags; + /** \brief Number of allocations in `pAllocations` array. + */ + uint32_t allocationCount; + /** \brief Pointer to array of allocations that can be defragmented. + + The array should have `allocationCount` elements. + The array should not contain nulls. + Elements in the array should be unique - same allocation cannot occur twice. + It is safe to pass allocations that are in the lost state - they are ignored. + All allocations not present in this array are considered non-moveable during this defragmentation. + */ + const VmaAllocation VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations; + /** \brief Optional, output. Pointer to array that will be filled with information whether the allocation at certain index has been changed during defragmentation. + + The array should have `allocationCount` elements. + You can pass null if you are not interested in this information. + */ + VkBool32* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationsChanged; + /** \brief Numer of pools in `pPools` array. + */ + uint32_t poolCount; + /** \brief Either null or pointer to array of pools to be defragmented. + + All the allocations in the specified pools can be moved during defragmentation + and there is no way to check if they were really moved as in `pAllocationsChanged`, + so you must query all the allocations in all these pools for new `VkDeviceMemory` + and offset using vmaGetAllocationInfo() if you might need to recreate buffers + and images bound to them. + + The array should have `poolCount` elements. + The array should not contain nulls. + Elements in the array should be unique - same pool cannot occur twice. + + Using this array is equivalent to specifying all allocations from the pools in `pAllocations`. + It might be more efficient. + */ + const VmaPool VMA_NOT_NULL * VMA_NULLABLE VMA_LEN_IF_NOT_NULL(poolCount) pPools; + /** \brief Maximum total numbers of bytes that can be copied while moving allocations to different places using transfers on CPU side, like `memcpy()`, `memmove()`. + + `VK_WHOLE_SIZE` means no limit. + */ + VkDeviceSize maxCpuBytesToMove; + /** \brief Maximum number of allocations that can be moved to a different place using transfers on CPU side, like `memcpy()`, `memmove()`. + + `UINT32_MAX` means no limit. + */ + uint32_t maxCpuAllocationsToMove; + /** \brief Maximum total numbers of bytes that can be copied while moving allocations to different places using transfers on GPU side, posted to `commandBuffer`. + + `VK_WHOLE_SIZE` means no limit. + */ + VkDeviceSize maxGpuBytesToMove; + /** \brief Maximum number of allocations that can be moved to a different place using transfers on GPU side, posted to `commandBuffer`. + + `UINT32_MAX` means no limit. + */ + uint32_t maxGpuAllocationsToMove; + /** \brief Optional. Command buffer where GPU copy commands will be posted. + + If not null, it must be a valid command buffer handle that supports Transfer queue type. + It must be in the recording state and outside of a render pass instance. + You need to submit it and make sure it finished execution before calling vmaDefragmentationEnd(). + + Passing null means that only CPU defragmentation will be performed. + */ + VkCommandBuffer VMA_NULLABLE commandBuffer; +} VmaDefragmentationInfo2; + +typedef struct VmaDefragmentationPassMoveInfo { + VmaAllocation VMA_NOT_NULL allocation; + VkDeviceMemory VMA_NOT_NULL_NON_DISPATCHABLE memory; + VkDeviceSize offset; +} VmaDefragmentationPassMoveInfo; + +/** \brief Parameters for incremental defragmentation steps. + +To be used with function vmaBeginDefragmentationPass(). +*/ +typedef struct VmaDefragmentationPassInfo { + uint32_t moveCount; + VmaDefragmentationPassMoveInfo* VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(moveCount) pMoves; +} VmaDefragmentationPassInfo; + +/** \brief Deprecated. Optional configuration parameters to be passed to function vmaDefragment(). + +\deprecated This is a part of the old interface. It is recommended to use structure #VmaDefragmentationInfo2 and function vmaDefragmentationBegin() instead. +*/ +typedef struct VmaDefragmentationInfo { + /** \brief Maximum total numbers of bytes that can be copied while moving allocations to different places. + + Default is `VK_WHOLE_SIZE`, which means no limit. + */ + VkDeviceSize maxBytesToMove; + /** \brief Maximum number of allocations that can be moved to different place. + + Default is `UINT32_MAX`, which means no limit. + */ + uint32_t maxAllocationsToMove; +} VmaDefragmentationInfo; + +/** \brief Statistics returned by function vmaDefragment(). */ +typedef struct VmaDefragmentationStats { + /// Total number of bytes that have been copied while moving allocations to different places. + VkDeviceSize bytesMoved; + /// Total number of bytes that have been released to the system by freeing empty `VkDeviceMemory` objects. + VkDeviceSize bytesFreed; + /// Number of allocations that have been moved to different places. + uint32_t allocationsMoved; + /// Number of empty `VkDeviceMemory` objects that have been released to the system. + uint32_t deviceMemoryBlocksFreed; +} VmaDefragmentationStats; + +/** \brief Begins defragmentation process. + +@param allocator Allocator object. +@param pInfo Structure filled with parameters of defragmentation. +@param[out] pStats Optional. Statistics of defragmentation. You can pass null if you are not interested in this information. +@param[out] pContext Context object that must be passed to vmaDefragmentationEnd() to finish defragmentation. +@return `VK_SUCCESS` and `*pContext == null` if defragmentation finished within this function call. `VK_NOT_READY` and `*pContext != null` if defragmentation has been started and you need to call vmaDefragmentationEnd() to finish it. Negative value in case of error. + +Use this function instead of old, deprecated vmaDefragment(). + +Warning! Between the call to vmaDefragmentationBegin() and vmaDefragmentationEnd(): + +- You should not use any of allocations passed as `pInfo->pAllocations` or + any allocations that belong to pools passed as `pInfo->pPools`, + including calling vmaGetAllocationInfo(), vmaTouchAllocation(), or access + their data. +- Some mutexes protecting internal data structures may be locked, so trying to + make or free any allocations, bind buffers or images, map memory, or launch + another simultaneous defragmentation in between may cause stall (when done on + another thread) or deadlock (when done on the same thread), unless you are + 100% sure that defragmented allocations are in different pools. +- Information returned via `pStats` and `pInfo->pAllocationsChanged` are undefined. + They become valid after call to vmaDefragmentationEnd(). +- If `pInfo->commandBuffer` is not null, you must submit that command buffer + and make sure it finished execution before calling vmaDefragmentationEnd(). + +For more information and important limitations regarding defragmentation, see documentation chapter: +[Defragmentation](@ref defragmentation). +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationBegin( + VmaAllocator VMA_NOT_NULL allocator, + const VmaDefragmentationInfo2* VMA_NOT_NULL pInfo, + VmaDefragmentationStats* VMA_NULLABLE pStats, + VmaDefragmentationContext VMA_NULLABLE * VMA_NOT_NULL pContext); + +/** \brief Ends defragmentation process. + +Use this function to finish defragmentation started by vmaDefragmentationBegin(). +It is safe to pass `context == null`. The function then does nothing. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationEnd( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NULLABLE context); + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentationPass( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NULLABLE context, + VmaDefragmentationPassInfo* VMA_NOT_NULL pInfo +); +VMA_CALL_PRE VkResult VMA_CALL_POST vmaEndDefragmentationPass( + VmaAllocator VMA_NOT_NULL allocator, + VmaDefragmentationContext VMA_NULLABLE context +); + +/** \brief Deprecated. Compacts memory by moving allocations. + +@param pAllocations Array of allocations that can be moved during this compation. +@param allocationCount Number of elements in pAllocations and pAllocationsChanged arrays. +@param[out] pAllocationsChanged Array of boolean values that will indicate whether matching allocation in pAllocations array has been moved. This parameter is optional. Pass null if you don't need this information. +@param pDefragmentationInfo Configuration parameters. Optional - pass null to use default values. +@param[out] pDefragmentationStats Statistics returned by the function. Optional - pass null if you don't need this information. +@return `VK_SUCCESS` if completed, negative error code in case of error. + +\deprecated This is a part of the old interface. It is recommended to use structure #VmaDefragmentationInfo2 and function vmaDefragmentationBegin() instead. + +This function works by moving allocations to different places (different +`VkDeviceMemory` objects and/or different offsets) in order to optimize memory +usage. Only allocations that are in `pAllocations` array can be moved. All other +allocations are considered nonmovable in this call. Basic rules: + +- Only allocations made in memory types that have + `VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT` and `VK_MEMORY_PROPERTY_HOST_COHERENT_BIT` + flags can be compacted. You may pass other allocations but it makes no sense - + these will never be moved. +- Custom pools created with #VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT or + #VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT flag are not defragmented. Allocations + passed to this function that come from such pools are ignored. +- Allocations created with #VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT or + created as dedicated allocations for any other reason are also ignored. +- Both allocations made with or without #VMA_ALLOCATION_CREATE_MAPPED_BIT + flag can be compacted. If not persistently mapped, memory will be mapped + temporarily inside this function if needed. +- You must not pass same #VmaAllocation object multiple times in `pAllocations` array. + +The function also frees empty `VkDeviceMemory` blocks. + +Warning: This function may be time-consuming, so you shouldn't call it too often +(like after every resource creation/destruction). +You can call it on special occasions (like when reloading a game level or +when you just destroyed a lot of objects). Calling it every frame may be OK, but +you should measure that on your platform. + +For more information, see [Defragmentation](@ref defragmentation) chapter. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragment( + VmaAllocator VMA_NOT_NULL allocator, + const VmaAllocation VMA_NOT_NULL * VMA_NOT_NULL VMA_LEN_IF_NOT_NULL(allocationCount) pAllocations, + size_t allocationCount, + VkBool32* VMA_NULLABLE VMA_LEN_IF_NOT_NULL(allocationCount) pAllocationsChanged, + const VmaDefragmentationInfo* VMA_NULLABLE pDefragmentationInfo, + VmaDefragmentationStats* VMA_NULLABLE pDefragmentationStats); + +/** \brief Binds buffer to allocation. + +Binds specified buffer to region of memory represented by specified allocation. +Gets `VkDeviceMemory` handle and offset from the allocation. +If you want to create a buffer, allocate memory for it and bind them together separately, +you should use this function for binding instead of standard `vkBindBufferMemory()`, +because it ensures proper synchronization so that when a `VkDeviceMemory` object is used by multiple +allocations, calls to `vkBind*Memory()` or `vkMapMemory()` won't happen from multiple threads simultaneously +(which is illegal in Vulkan). + +It is recommended to use function vmaCreateBuffer() instead of this one. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer); + +/** \brief Binds buffer to allocation with additional parameters. + +@param allocationLocalOffset Additional offset to be added while binding, relative to the beginnig of the `allocation`. Normally it should be 0. +@param pNext A chain of structures to be attached to `VkBindBufferMemoryInfoKHR` structure used internally. Normally it should be null. + +This function is similar to vmaBindBufferMemory(), but it provides additional parameters. + +If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag +or with VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_1`. Otherwise the call fails. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize allocationLocalOffset, + VkBuffer VMA_NOT_NULL_NON_DISPATCHABLE buffer, + const void* VMA_NULLABLE pNext); + +/** \brief Binds image to allocation. + +Binds specified image to region of memory represented by specified allocation. +Gets `VkDeviceMemory` handle and offset from the allocation. +If you want to create an image, allocate memory for it and bind them together separately, +you should use this function for binding instead of standard `vkBindImageMemory()`, +because it ensures proper synchronization so that when a `VkDeviceMemory` object is used by multiple +allocations, calls to `vkBind*Memory()` or `vkMapMemory()` won't happen from multiple threads simultaneously +(which is illegal in Vulkan). + +It is recommended to use function vmaCreateImage() instead of this one. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkImage VMA_NOT_NULL_NON_DISPATCHABLE image); + +/** \brief Binds image to allocation with additional parameters. + +@param allocationLocalOffset Additional offset to be added while binding, relative to the beginnig of the `allocation`. Normally it should be 0. +@param pNext A chain of structures to be attached to `VkBindImageMemoryInfoKHR` structure used internally. Normally it should be null. + +This function is similar to vmaBindImageMemory(), but it provides additional parameters. + +If `pNext` is not null, #VmaAllocator object must have been created with #VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT flag +or with VmaAllocatorCreateInfo::vulkanApiVersion `== VK_API_VERSION_1_1`. Otherwise the call fails. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( + VmaAllocator VMA_NOT_NULL allocator, + VmaAllocation VMA_NOT_NULL allocation, + VkDeviceSize allocationLocalOffset, + VkImage VMA_NOT_NULL_NON_DISPATCHABLE image, + const void* VMA_NULLABLE pNext); + +/** +@param[out] pBuffer Buffer that was created. +@param[out] pAllocation Allocation that was created. +@param[out] pAllocationInfo Optional. Information about allocated memory. It can be later fetched using function vmaGetAllocationInfo(). + +This function automatically: + +-# Creates buffer. +-# Allocates appropriate memory for it. +-# Binds the buffer with the memory. + +If any of these operations fail, buffer and allocation are not created, +returned value is negative error code, *pBuffer and *pAllocation are null. + +If the function succeeded, you must destroy both buffer and allocation when you +no longer need them using either convenience function vmaDestroyBuffer() or +separately, using `vkDestroyBuffer()` and vmaFreeMemory(). + +If VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT flag was used, +VK_KHR_dedicated_allocation extension is used internally to query driver whether +it requires or prefers the new buffer to have dedicated allocation. If yes, +and if dedicated allocation is possible (VmaAllocationCreateInfo::pool is null +and VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT is not used), it creates dedicated +allocation for this buffer, just like when using +VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( + VmaAllocator VMA_NOT_NULL allocator, + const VkBufferCreateInfo* VMA_NOT_NULL pBufferCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE * VMA_NOT_NULL pBuffer, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); + +/** \brief Destroys Vulkan buffer and frees allocated memory. + +This is just a convenience function equivalent to: + +\code +vkDestroyBuffer(device, buffer, allocationCallbacks); +vmaFreeMemory(allocator, allocation); +\endcode + +It it safe to pass null as buffer and/or allocation. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyBuffer( + VmaAllocator VMA_NOT_NULL allocator, + VkBuffer VMA_NULLABLE_NON_DISPATCHABLE buffer, + VmaAllocation VMA_NULLABLE allocation); + +/// Function similar to vmaCreateBuffer(). +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateImage( + VmaAllocator VMA_NOT_NULL allocator, + const VkImageCreateInfo* VMA_NOT_NULL pImageCreateInfo, + const VmaAllocationCreateInfo* VMA_NOT_NULL pAllocationCreateInfo, + VkImage VMA_NULLABLE_NON_DISPATCHABLE * VMA_NOT_NULL pImage, + VmaAllocation VMA_NULLABLE * VMA_NOT_NULL pAllocation, + VmaAllocationInfo* VMA_NULLABLE pAllocationInfo); + +/** \brief Destroys Vulkan image and frees allocated memory. + +This is just a convenience function equivalent to: + +\code +vkDestroyImage(device, image, allocationCallbacks); +vmaFreeMemory(allocator, allocation); +\endcode + +It it safe to pass null as image and/or allocation. +*/ +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( + VmaAllocator VMA_NOT_NULL allocator, + VkImage VMA_NULLABLE_NON_DISPATCHABLE image, + VmaAllocation VMA_NULLABLE allocation); + +#ifdef __cplusplus +} +#endif + +#endif // AMD_VULKAN_MEMORY_ALLOCATOR_H + +// For Visual Studio IntelliSense. +#if defined(__cplusplus) && defined(__INTELLISENSE__) +#define VMA_IMPLEMENTATION +#endif + +#ifdef VMA_IMPLEMENTATION +#undef VMA_IMPLEMENTATION + +#include +#include +#include +#include + +/******************************************************************************* +CONFIGURATION SECTION + +Define some of these macros before each #include of this header or change them +here if you need other then default behavior depending on your environment. +*/ + +/* +Define this macro to 1 to make the library fetch pointers to Vulkan functions +internally, like: + + vulkanFunctions.vkAllocateMemory = &vkAllocateMemory; +*/ +#if !defined(VMA_STATIC_VULKAN_FUNCTIONS) && !defined(VK_NO_PROTOTYPES) + #define VMA_STATIC_VULKAN_FUNCTIONS 1 +#endif + +/* +Define this macro to 1 to make the library fetch pointers to Vulkan functions +internally, like: + + vulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkGetDeviceProcAddr(m_hDevice, vkAllocateMemory); +*/ +#if !defined(VMA_DYNAMIC_VULKAN_FUNCTIONS) + #define VMA_DYNAMIC_VULKAN_FUNCTIONS 1 +#endif + +// Define this macro to 1 to make the library use STL containers instead of its own implementation. +//#define VMA_USE_STL_CONTAINERS 1 + +/* Set this macro to 1 to make the library including and using STL containers: +std::pair, std::vector, std::list, std::unordered_map. + +Set it to 0 or undefined to make the library using its own implementation of +the containers. +*/ +#if VMA_USE_STL_CONTAINERS + #define VMA_USE_STL_VECTOR 1 + #define VMA_USE_STL_UNORDERED_MAP 1 + #define VMA_USE_STL_LIST 1 +#endif + +#ifndef VMA_USE_STL_SHARED_MUTEX + // Compiler conforms to C++17. + #if __cplusplus >= 201703L + #define VMA_USE_STL_SHARED_MUTEX 1 + // Visual studio defines __cplusplus properly only when passed additional parameter: /Zc:__cplusplus + // Otherwise it's always 199711L, despite shared_mutex works since Visual Studio 2015 Update 2. + // See: https://blogs.msdn.microsoft.com/vcblog/2018/04/09/msvc-now-correctly-reports-__cplusplus/ + #elif defined(_MSC_FULL_VER) && _MSC_FULL_VER >= 190023918 && __cplusplus == 199711L && _MSVC_LANG >= 201703L + #define VMA_USE_STL_SHARED_MUTEX 1 + #else + #define VMA_USE_STL_SHARED_MUTEX 0 + #endif +#endif + +/* +THESE INCLUDES ARE NOT ENABLED BY DEFAULT. +Library has its own container implementation. +*/ +#if VMA_USE_STL_VECTOR + #include +#endif + +#if VMA_USE_STL_UNORDERED_MAP + #include +#endif + +#if VMA_USE_STL_LIST + #include +#endif + +/* +Following headers are used in this CONFIGURATION section only, so feel free to +remove them if not needed. +*/ +#include // for assert +#include // for min, max +#include + +#ifndef VMA_NULL + // Value used as null pointer. Define it to e.g.: nullptr, NULL, 0, (void*)0. + #define VMA_NULL nullptr +#endif + +#if defined(__ANDROID_API__) && (__ANDROID_API__ < 16) +#include +void *vma_aligned_alloc(size_t alignment, size_t size) +{ + // alignment must be >= sizeof(void*) + if(alignment < sizeof(void*)) + { + alignment = sizeof(void*); + } + + return memalign(alignment, size); +} +#elif defined(__APPLE__) || defined(__ANDROID__) || (defined(__linux__) && defined(__GLIBCXX__) && !defined(_GLIBCXX_HAVE_ALIGNED_ALLOC)) +#include + +#if defined(__APPLE__) +#include +#endif + +void *vma_aligned_alloc(size_t alignment, size_t size) +{ +#if defined(__APPLE__) && (defined(MAC_OS_X_VERSION_10_16) || defined(__IPHONE_14_0)) +#if MAC_OS_X_VERSION_MAX_ALLOWED >= MAC_OS_X_VERSION_10_16 || __IPHONE_OS_VERSION_MAX_ALLOWED >= __IPHONE_14_0 + // For C++14, usr/include/malloc/_malloc.h declares aligned_alloc()) only + // with the MacOSX11.0 SDK in Xcode 12 (which is what adds + // MAC_OS_X_VERSION_10_16), even though the function is marked + // availabe for 10.15. That's why the preprocessor checks for 10.16 but + // the __builtin_available checks for 10.15. + // People who use C++17 could call aligned_alloc with the 10.15 SDK already. + if (__builtin_available(macOS 10.15, iOS 13, *)) + return aligned_alloc(alignment, size); +#endif +#endif + // alignment must be >= sizeof(void*) + if(alignment < sizeof(void*)) + { + alignment = sizeof(void*); + } + + void *pointer; + if(posix_memalign(&pointer, alignment, size) == 0) + return pointer; + return VMA_NULL; +} +#elif defined(_WIN32) +void *vma_aligned_alloc(size_t alignment, size_t size) +{ + return _aligned_malloc(size, alignment); +} +#else +void *vma_aligned_alloc(size_t alignment, size_t size) +{ + return aligned_alloc(alignment, size); +} +#endif + +// If your compiler is not compatible with C++11 and definition of +// aligned_alloc() function is missing, uncommeting following line may help: + +//#include + +// Normal assert to check for programmer's errors, especially in Debug configuration. +#ifndef VMA_ASSERT + #ifdef NDEBUG + #define VMA_ASSERT(expr) + #else + #define VMA_ASSERT(expr) assert(expr) + #endif +#endif + +// Assert that will be called very often, like inside data structures e.g. operator[]. +// Making it non-empty can make program slow. +#ifndef VMA_HEAVY_ASSERT + #ifdef NDEBUG + #define VMA_HEAVY_ASSERT(expr) + #else + #define VMA_HEAVY_ASSERT(expr) //VMA_ASSERT(expr) + #endif +#endif + +#ifndef VMA_ALIGN_OF + #define VMA_ALIGN_OF(type) (__alignof(type)) +#endif + +#ifndef VMA_SYSTEM_ALIGNED_MALLOC + #define VMA_SYSTEM_ALIGNED_MALLOC(size, alignment) vma_aligned_alloc((alignment), (size)) +#endif + +#ifndef VMA_SYSTEM_FREE + #if defined(_WIN32) + #define VMA_SYSTEM_FREE(ptr) _aligned_free(ptr) + #else + #define VMA_SYSTEM_FREE(ptr) free(ptr) + #endif +#endif + +#ifndef VMA_MIN + #define VMA_MIN(v1, v2) (std::min((v1), (v2))) +#endif + +#ifndef VMA_MAX + #define VMA_MAX(v1, v2) (std::max((v1), (v2))) +#endif + +#ifndef VMA_SWAP + #define VMA_SWAP(v1, v2) std::swap((v1), (v2)) +#endif + +#ifndef VMA_SORT + #define VMA_SORT(beg, end, cmp) std::sort(beg, end, cmp) +#endif + +#ifndef VMA_DEBUG_LOG + #define VMA_DEBUG_LOG(format, ...) + /* + #define VMA_DEBUG_LOG(format, ...) do { \ + printf(format, __VA_ARGS__); \ + printf("\n"); \ + } while(false) + */ +#endif + +// Define this macro to 1 to enable functions: vmaBuildStatsString, vmaFreeStatsString. +#if VMA_STATS_STRING_ENABLED + static inline void VmaUint32ToStr(char* outStr, size_t strLen, uint32_t num) + { + snprintf(outStr, strLen, "%u", static_cast(num)); + } + static inline void VmaUint64ToStr(char* outStr, size_t strLen, uint64_t num) + { + snprintf(outStr, strLen, "%llu", static_cast(num)); + } + static inline void VmaPtrToStr(char* outStr, size_t strLen, const void* ptr) + { + snprintf(outStr, strLen, "%p", ptr); + } +#endif + +#ifndef VMA_MUTEX + class VmaMutex + { + public: + void Lock() { m_Mutex.lock(); } + void Unlock() { m_Mutex.unlock(); } + bool TryLock() { return m_Mutex.try_lock(); } + private: + std::mutex m_Mutex; + }; + #define VMA_MUTEX VmaMutex +#endif + +// Read-write mutex, where "read" is shared access, "write" is exclusive access. +#ifndef VMA_RW_MUTEX + #if VMA_USE_STL_SHARED_MUTEX + // Use std::shared_mutex from C++17. + #include + class VmaRWMutex + { + public: + void LockRead() { m_Mutex.lock_shared(); } + void UnlockRead() { m_Mutex.unlock_shared(); } + bool TryLockRead() { return m_Mutex.try_lock_shared(); } + void LockWrite() { m_Mutex.lock(); } + void UnlockWrite() { m_Mutex.unlock(); } + bool TryLockWrite() { return m_Mutex.try_lock(); } + private: + std::shared_mutex m_Mutex; + }; + #define VMA_RW_MUTEX VmaRWMutex + #elif defined(_WIN32) && defined(WINVER) && WINVER >= 0x0600 + // Use SRWLOCK from WinAPI. + // Minimum supported client = Windows Vista, server = Windows Server 2008. + class VmaRWMutex + { + public: + VmaRWMutex() { InitializeSRWLock(&m_Lock); } + void LockRead() { AcquireSRWLockShared(&m_Lock); } + void UnlockRead() { ReleaseSRWLockShared(&m_Lock); } + bool TryLockRead() { return TryAcquireSRWLockShared(&m_Lock) != FALSE; } + void LockWrite() { AcquireSRWLockExclusive(&m_Lock); } + void UnlockWrite() { ReleaseSRWLockExclusive(&m_Lock); } + bool TryLockWrite() { return TryAcquireSRWLockExclusive(&m_Lock) != FALSE; } + private: + SRWLOCK m_Lock; + }; + #define VMA_RW_MUTEX VmaRWMutex + #else + // Less efficient fallback: Use normal mutex. + class VmaRWMutex + { + public: + void LockRead() { m_Mutex.Lock(); } + void UnlockRead() { m_Mutex.Unlock(); } + bool TryLockRead() { return m_Mutex.TryLock(); } + void LockWrite() { m_Mutex.Lock(); } + void UnlockWrite() { m_Mutex.Unlock(); } + bool TryLockWrite() { return m_Mutex.TryLock(); } + private: + VMA_MUTEX m_Mutex; + }; + #define VMA_RW_MUTEX VmaRWMutex + #endif // #if VMA_USE_STL_SHARED_MUTEX +#endif // #ifndef VMA_RW_MUTEX + +/* +If providing your own implementation, you need to implement a subset of std::atomic. +*/ +#ifndef VMA_ATOMIC_UINT32 + #include + #define VMA_ATOMIC_UINT32 std::atomic +#endif + +#ifndef VMA_ATOMIC_UINT64 + #include + #define VMA_ATOMIC_UINT64 std::atomic +#endif + +#ifndef VMA_DEBUG_ALWAYS_DEDICATED_MEMORY + /** + Every allocation will have its own memory block. + Define to 1 for debugging purposes only. + */ + #define VMA_DEBUG_ALWAYS_DEDICATED_MEMORY (0) +#endif + +#ifndef VMA_DEBUG_ALIGNMENT + /** + Minimum alignment of all allocations, in bytes. + Set to more than 1 for debugging purposes only. Must be power of two. + */ + #define VMA_DEBUG_ALIGNMENT (1) +#endif + +#ifndef VMA_DEBUG_MARGIN + /** + Minimum margin before and after every allocation, in bytes. + Set nonzero for debugging purposes only. + */ + #define VMA_DEBUG_MARGIN (0) +#endif + +#ifndef VMA_DEBUG_INITIALIZE_ALLOCATIONS + /** + Define this macro to 1 to automatically fill new allocations and destroyed + allocations with some bit pattern. + */ + #define VMA_DEBUG_INITIALIZE_ALLOCATIONS (0) +#endif + +#ifndef VMA_DEBUG_DETECT_CORRUPTION + /** + Define this macro to 1 together with non-zero value of VMA_DEBUG_MARGIN to + enable writing magic value to the margin before and after every allocation and + validating it, so that memory corruptions (out-of-bounds writes) are detected. + */ + #define VMA_DEBUG_DETECT_CORRUPTION (0) +#endif + +#ifndef VMA_DEBUG_GLOBAL_MUTEX + /** + Set this to 1 for debugging purposes only, to enable single mutex protecting all + entry calls to the library. Can be useful for debugging multithreading issues. + */ + #define VMA_DEBUG_GLOBAL_MUTEX (0) +#endif + +#ifndef VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY + /** + Minimum value for VkPhysicalDeviceLimits::bufferImageGranularity. + Set to more than 1 for debugging purposes only. Must be power of two. + */ + #define VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY (1) +#endif + +#ifndef VMA_SMALL_HEAP_MAX_SIZE + /// Maximum size of a memory heap in Vulkan to consider it "small". + #define VMA_SMALL_HEAP_MAX_SIZE (1024ull * 1024 * 1024) +#endif + +#ifndef VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE + /// Default size of a block allocated as single VkDeviceMemory from a "large" heap. + #define VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE (256ull * 1024 * 1024) +#endif + +#ifndef VMA_CLASS_NO_COPY + #define VMA_CLASS_NO_COPY(className) \ + private: \ + className(const className&) = delete; \ + className& operator=(const className&) = delete; +#endif + +static const uint32_t VMA_FRAME_INDEX_LOST = UINT32_MAX; + +// Decimal 2139416166, float NaN, little-endian binary 66 E6 84 7F. +static const uint32_t VMA_CORRUPTION_DETECTION_MAGIC_VALUE = 0x7F84E666; + +static const uint8_t VMA_ALLOCATION_FILL_PATTERN_CREATED = 0xDC; +static const uint8_t VMA_ALLOCATION_FILL_PATTERN_DESTROYED = 0xEF; + +/******************************************************************************* +END OF CONFIGURATION +*/ + +// # Copy of some Vulkan definitions so we don't need to check their existence just to handle few constants. + +static const uint32_t VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY = 0x00000040; +static const uint32_t VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY = 0x00000080; +static const uint32_t VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY = 0x00020000; + +static const uint32_t VMA_ALLOCATION_INTERNAL_STRATEGY_MIN_OFFSET = 0x10000000u; + +static VkAllocationCallbacks VmaEmptyAllocationCallbacks = { + VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL, VMA_NULL }; + +// Returns number of bits set to 1 in (v). +static inline uint32_t VmaCountBitsSet(uint32_t v) +{ + uint32_t c = v - ((v >> 1) & 0x55555555); + c = ((c >> 2) & 0x33333333) + (c & 0x33333333); + c = ((c >> 4) + c) & 0x0F0F0F0F; + c = ((c >> 8) + c) & 0x00FF00FF; + c = ((c >> 16) + c) & 0x0000FFFF; + return c; +} + +/* +Returns true if given number is a power of two. +T must be unsigned integer number or signed integer but always nonnegative. +For 0 returns true. +*/ +template +inline bool VmaIsPow2(T x) +{ + return (x & (x-1)) == 0; +} + +// Aligns given value up to nearest multiply of align value. For example: VmaAlignUp(11, 8) = 16. +// Use types like uint32_t, uint64_t as T. +template +static inline T VmaAlignUp(T val, T alignment) +{ + VMA_HEAVY_ASSERT(VmaIsPow2(alignment)); + return (val + alignment - 1) & ~(alignment - 1); +} +// Aligns given value down to nearest multiply of align value. For example: VmaAlignUp(11, 8) = 8. +// Use types like uint32_t, uint64_t as T. +template +static inline T VmaAlignDown(T val, T alignment) +{ + VMA_HEAVY_ASSERT(VmaIsPow2(alignment)); + return val & ~(alignment - 1); +} + +// Division with mathematical rounding to nearest number. +template +static inline T VmaRoundDiv(T x, T y) +{ + return (x + (y / (T)2)) / y; +} + +// Returns smallest power of 2 greater or equal to v. +static inline uint32_t VmaNextPow2(uint32_t v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v++; + return v; +} +static inline uint64_t VmaNextPow2(uint64_t v) +{ + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v++; + return v; +} + +// Returns largest power of 2 less or equal to v. +static inline uint32_t VmaPrevPow2(uint32_t v) +{ + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v = v ^ (v >> 1); + return v; +} +static inline uint64_t VmaPrevPow2(uint64_t v) +{ + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; + v |= v >> 32; + v = v ^ (v >> 1); + return v; +} + +static inline bool VmaStrIsEmpty(const char* pStr) +{ + return pStr == VMA_NULL || *pStr == '\0'; +} + +#if VMA_STATS_STRING_ENABLED + +static const char* VmaAlgorithmToStr(uint32_t algorithm) +{ + switch(algorithm) + { + case VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT: + return "Linear"; + case VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT: + return "Buddy"; + case 0: + return "Default"; + default: + VMA_ASSERT(0); + return ""; + } +} + +#endif // #if VMA_STATS_STRING_ENABLED + +#ifndef VMA_SORT + +template +Iterator VmaQuickSortPartition(Iterator beg, Iterator end, Compare cmp) +{ + Iterator centerValue = end; --centerValue; + Iterator insertIndex = beg; + for(Iterator memTypeIndex = beg; memTypeIndex < centerValue; ++memTypeIndex) + { + if(cmp(*memTypeIndex, *centerValue)) + { + if(insertIndex != memTypeIndex) + { + VMA_SWAP(*memTypeIndex, *insertIndex); + } + ++insertIndex; + } + } + if(insertIndex != centerValue) + { + VMA_SWAP(*insertIndex, *centerValue); + } + return insertIndex; +} + +template +void VmaQuickSort(Iterator beg, Iterator end, Compare cmp) +{ + if(beg < end) + { + Iterator it = VmaQuickSortPartition(beg, end, cmp); + VmaQuickSort(beg, it, cmp); + VmaQuickSort(it + 1, end, cmp); + } +} + +#define VMA_SORT(beg, end, cmp) VmaQuickSort(beg, end, cmp) + +#endif // #ifndef VMA_SORT + +/* +Returns true if two memory blocks occupy overlapping pages. +ResourceA must be in less memory offset than ResourceB. + +Algorithm is based on "Vulkan 1.0.39 - A Specification (with all registered Vulkan extensions)" +chapter 11.6 "Resource Memory Association", paragraph "Buffer-Image Granularity". +*/ +static inline bool VmaBlocksOnSamePage( + VkDeviceSize resourceAOffset, + VkDeviceSize resourceASize, + VkDeviceSize resourceBOffset, + VkDeviceSize pageSize) +{ + VMA_ASSERT(resourceAOffset + resourceASize <= resourceBOffset && resourceASize > 0 && pageSize > 0); + VkDeviceSize resourceAEnd = resourceAOffset + resourceASize - 1; + VkDeviceSize resourceAEndPage = resourceAEnd & ~(pageSize - 1); + VkDeviceSize resourceBStart = resourceBOffset; + VkDeviceSize resourceBStartPage = resourceBStart & ~(pageSize - 1); + return resourceAEndPage == resourceBStartPage; +} + +enum VmaSuballocationType +{ + VMA_SUBALLOCATION_TYPE_FREE = 0, + VMA_SUBALLOCATION_TYPE_UNKNOWN = 1, + VMA_SUBALLOCATION_TYPE_BUFFER = 2, + VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN = 3, + VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR = 4, + VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL = 5, + VMA_SUBALLOCATION_TYPE_MAX_ENUM = 0x7FFFFFFF +}; + +/* +Returns true if given suballocation types could conflict and must respect +VkPhysicalDeviceLimits::bufferImageGranularity. They conflict if one is buffer +or linear image and another one is optimal image. If type is unknown, behave +conservatively. +*/ +static inline bool VmaIsBufferImageGranularityConflict( + VmaSuballocationType suballocType1, + VmaSuballocationType suballocType2) +{ + if(suballocType1 > suballocType2) + { + VMA_SWAP(suballocType1, suballocType2); + } + + switch(suballocType1) + { + case VMA_SUBALLOCATION_TYPE_FREE: + return false; + case VMA_SUBALLOCATION_TYPE_UNKNOWN: + return true; + case VMA_SUBALLOCATION_TYPE_BUFFER: + return + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; + case VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN: + return + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR || + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; + case VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR: + return + suballocType2 == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL; + case VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL: + return false; + default: + VMA_ASSERT(0); + return true; + } +} + +static void VmaWriteMagicValue(void* pData, VkDeviceSize offset) +{ +#if VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_DETECT_CORRUPTION + uint32_t* pDst = (uint32_t*)((char*)pData + offset); + const size_t numberCount = VMA_DEBUG_MARGIN / sizeof(uint32_t); + for(size_t i = 0; i < numberCount; ++i, ++pDst) + { + *pDst = VMA_CORRUPTION_DETECTION_MAGIC_VALUE; + } +#else + // no-op +#endif +} + +static bool VmaValidateMagicValue(const void* pData, VkDeviceSize offset) +{ +#if VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_DETECT_CORRUPTION + const uint32_t* pSrc = (const uint32_t*)((const char*)pData + offset); + const size_t numberCount = VMA_DEBUG_MARGIN / sizeof(uint32_t); + for(size_t i = 0; i < numberCount; ++i, ++pSrc) + { + if(*pSrc != VMA_CORRUPTION_DETECTION_MAGIC_VALUE) + { + return false; + } + } +#endif + return true; +} + +/* +Fills structure with parameters of an example buffer to be used for transfers +during GPU memory defragmentation. +*/ +static void VmaFillGpuDefragmentationBufferCreateInfo(VkBufferCreateInfo& outBufCreateInfo) +{ + memset(&outBufCreateInfo, 0, sizeof(outBufCreateInfo)); + outBufCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + outBufCreateInfo.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + outBufCreateInfo.size = (VkDeviceSize)VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE; // Example size. +} + +// Helper RAII class to lock a mutex in constructor and unlock it in destructor (at the end of scope). +struct VmaMutexLock +{ + VMA_CLASS_NO_COPY(VmaMutexLock) +public: + VmaMutexLock(VMA_MUTEX& mutex, bool useMutex = true) : + m_pMutex(useMutex ? &mutex : VMA_NULL) + { if(m_pMutex) { m_pMutex->Lock(); } } + ~VmaMutexLock() + { if(m_pMutex) { m_pMutex->Unlock(); } } +private: + VMA_MUTEX* m_pMutex; +}; + +// Helper RAII class to lock a RW mutex in constructor and unlock it in destructor (at the end of scope), for reading. +struct VmaMutexLockRead +{ + VMA_CLASS_NO_COPY(VmaMutexLockRead) +public: + VmaMutexLockRead(VMA_RW_MUTEX& mutex, bool useMutex) : + m_pMutex(useMutex ? &mutex : VMA_NULL) + { if(m_pMutex) { m_pMutex->LockRead(); } } + ~VmaMutexLockRead() { if(m_pMutex) { m_pMutex->UnlockRead(); } } +private: + VMA_RW_MUTEX* m_pMutex; +}; + +// Helper RAII class to lock a RW mutex in constructor and unlock it in destructor (at the end of scope), for writing. +struct VmaMutexLockWrite +{ + VMA_CLASS_NO_COPY(VmaMutexLockWrite) +public: + VmaMutexLockWrite(VMA_RW_MUTEX& mutex, bool useMutex) : + m_pMutex(useMutex ? &mutex : VMA_NULL) + { if(m_pMutex) { m_pMutex->LockWrite(); } } + ~VmaMutexLockWrite() { if(m_pMutex) { m_pMutex->UnlockWrite(); } } +private: + VMA_RW_MUTEX* m_pMutex; +}; + +#if VMA_DEBUG_GLOBAL_MUTEX + static VMA_MUTEX gDebugGlobalMutex; + #define VMA_DEBUG_GLOBAL_MUTEX_LOCK VmaMutexLock debugGlobalMutexLock(gDebugGlobalMutex, true); +#else + #define VMA_DEBUG_GLOBAL_MUTEX_LOCK +#endif + +// Minimum size of a free suballocation to register it in the free suballocation collection. +static const VkDeviceSize VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER = 16; + +/* +Performs binary search and returns iterator to first element that is greater or +equal to (key), according to comparison (cmp). + +Cmp should return true if first argument is less than second argument. + +Returned value is the found element, if present in the collection or place where +new element with value (key) should be inserted. +*/ +template +static IterT VmaBinaryFindFirstNotLess(IterT beg, IterT end, const KeyT &key, const CmpLess& cmp) +{ + size_t down = 0, up = (end - beg); + while(down < up) + { + const size_t mid = (down + up) / 2; + if(cmp(*(beg+mid), key)) + { + down = mid + 1; + } + else + { + up = mid; + } + } + return beg + down; +} + +template +IterT VmaBinaryFindSorted(const IterT& beg, const IterT& end, const KeyT& value, const CmpLess& cmp) +{ + IterT it = VmaBinaryFindFirstNotLess( + beg, end, value, cmp); + if(it == end || + (!cmp(*it, value) && !cmp(value, *it))) + { + return it; + } + return end; +} + +/* +Returns true if all pointers in the array are not-null and unique. +Warning! O(n^2) complexity. Use only inside VMA_HEAVY_ASSERT. +T must be pointer type, e.g. VmaAllocation, VmaPool. +*/ +template +static bool VmaValidatePointerArray(uint32_t count, const T* arr) +{ + for(uint32_t i = 0; i < count; ++i) + { + const T iPtr = arr[i]; + if(iPtr == VMA_NULL) + { + return false; + } + for(uint32_t j = i + 1; j < count; ++j) + { + if(iPtr == arr[j]) + { + return false; + } + } + } + return true; +} + +template +static inline void VmaPnextChainPushFront(MainT* mainStruct, NewT* newStruct) +{ + newStruct->pNext = mainStruct->pNext; + mainStruct->pNext = newStruct; +} + +//////////////////////////////////////////////////////////////////////////////// +// Memory allocation + +static void* VmaMalloc(const VkAllocationCallbacks* pAllocationCallbacks, size_t size, size_t alignment) +{ + void* result = VMA_NULL; + if((pAllocationCallbacks != VMA_NULL) && + (pAllocationCallbacks->pfnAllocation != VMA_NULL)) + { + result = (*pAllocationCallbacks->pfnAllocation)( + pAllocationCallbacks->pUserData, + size, + alignment, + VK_SYSTEM_ALLOCATION_SCOPE_OBJECT); + } + else + { + result = VMA_SYSTEM_ALIGNED_MALLOC(size, alignment); + } + VMA_ASSERT(result != VMA_NULL && "CPU memory allocation failed."); + return result; +} + +static void VmaFree(const VkAllocationCallbacks* pAllocationCallbacks, void* ptr) +{ + if((pAllocationCallbacks != VMA_NULL) && + (pAllocationCallbacks->pfnFree != VMA_NULL)) + { + (*pAllocationCallbacks->pfnFree)(pAllocationCallbacks->pUserData, ptr); + } + else + { + VMA_SYSTEM_FREE(ptr); + } +} + +template +static T* VmaAllocate(const VkAllocationCallbacks* pAllocationCallbacks) +{ + return (T*)VmaMalloc(pAllocationCallbacks, sizeof(T), VMA_ALIGN_OF(T)); +} + +template +static T* VmaAllocateArray(const VkAllocationCallbacks* pAllocationCallbacks, size_t count) +{ + return (T*)VmaMalloc(pAllocationCallbacks, sizeof(T) * count, VMA_ALIGN_OF(T)); +} + +#define vma_new(allocator, type) new(VmaAllocate(allocator))(type) + +#define vma_new_array(allocator, type, count) new(VmaAllocateArray((allocator), (count)))(type) + +template +static void vma_delete(const VkAllocationCallbacks* pAllocationCallbacks, T* ptr) +{ + ptr->~T(); + VmaFree(pAllocationCallbacks, ptr); +} + +template +static void vma_delete_array(const VkAllocationCallbacks* pAllocationCallbacks, T* ptr, size_t count) +{ + if(ptr != VMA_NULL) + { + for(size_t i = count; i--; ) + { + ptr[i].~T(); + } + VmaFree(pAllocationCallbacks, ptr); + } +} + +static char* VmaCreateStringCopy(const VkAllocationCallbacks* allocs, const char* srcStr) +{ + if(srcStr != VMA_NULL) + { + const size_t len = strlen(srcStr); + char* const result = vma_new_array(allocs, char, len + 1); + memcpy(result, srcStr, len + 1); + return result; + } + else + { + return VMA_NULL; + } +} + +static void VmaFreeString(const VkAllocationCallbacks* allocs, char* str) +{ + if(str != VMA_NULL) + { + const size_t len = strlen(str); + vma_delete_array(allocs, str, len + 1); + } +} + +// STL-compatible allocator. +template +class VmaStlAllocator +{ +public: + const VkAllocationCallbacks* const m_pCallbacks; + typedef T value_type; + + VmaStlAllocator(const VkAllocationCallbacks* pCallbacks) : m_pCallbacks(pCallbacks) { } + template VmaStlAllocator(const VmaStlAllocator& src) : m_pCallbacks(src.m_pCallbacks) { } + + T* allocate(size_t n) { return VmaAllocateArray(m_pCallbacks, n); } + void deallocate(T* p, size_t n) { VmaFree(m_pCallbacks, p); } + + template + bool operator==(const VmaStlAllocator& rhs) const + { + return m_pCallbacks == rhs.m_pCallbacks; + } + template + bool operator!=(const VmaStlAllocator& rhs) const + { + return m_pCallbacks != rhs.m_pCallbacks; + } + + VmaStlAllocator& operator=(const VmaStlAllocator& x) = delete; +}; + +#if VMA_USE_STL_VECTOR + +#define VmaVector std::vector + +template +static void VmaVectorInsert(std::vector& vec, size_t index, const T& item) +{ + vec.insert(vec.begin() + index, item); +} + +template +static void VmaVectorRemove(std::vector& vec, size_t index) +{ + vec.erase(vec.begin() + index); +} + +#else // #if VMA_USE_STL_VECTOR + +/* Class with interface compatible with subset of std::vector. +T must be POD because constructors and destructors are not called and memcpy is +used for these objects. */ +template +class VmaVector +{ +public: + typedef T value_type; + + VmaVector(const AllocatorT& allocator) : + m_Allocator(allocator), + m_pArray(VMA_NULL), + m_Count(0), + m_Capacity(0) + { + } + + VmaVector(size_t count, const AllocatorT& allocator) : + m_Allocator(allocator), + m_pArray(count ? (T*)VmaAllocateArray(allocator.m_pCallbacks, count) : VMA_NULL), + m_Count(count), + m_Capacity(count) + { + } + + // This version of the constructor is here for compatibility with pre-C++14 std::vector. + // value is unused. + VmaVector(size_t count, const T& value, const AllocatorT& allocator) + : VmaVector(count, allocator) {} + + VmaVector(const VmaVector& src) : + m_Allocator(src.m_Allocator), + m_pArray(src.m_Count ? (T*)VmaAllocateArray(src.m_Allocator.m_pCallbacks, src.m_Count) : VMA_NULL), + m_Count(src.m_Count), + m_Capacity(src.m_Count) + { + if(m_Count != 0) + { + memcpy(m_pArray, src.m_pArray, m_Count * sizeof(T)); + } + } + + ~VmaVector() + { + VmaFree(m_Allocator.m_pCallbacks, m_pArray); + } + + VmaVector& operator=(const VmaVector& rhs) + { + if(&rhs != this) + { + resize(rhs.m_Count); + if(m_Count != 0) + { + memcpy(m_pArray, rhs.m_pArray, m_Count * sizeof(T)); + } + } + return *this; + } + + bool empty() const { return m_Count == 0; } + size_t size() const { return m_Count; } + T* data() { return m_pArray; } + const T* data() const { return m_pArray; } + + T& operator[](size_t index) + { + VMA_HEAVY_ASSERT(index < m_Count); + return m_pArray[index]; + } + const T& operator[](size_t index) const + { + VMA_HEAVY_ASSERT(index < m_Count); + return m_pArray[index]; + } + + T& front() + { + VMA_HEAVY_ASSERT(m_Count > 0); + return m_pArray[0]; + } + const T& front() const + { + VMA_HEAVY_ASSERT(m_Count > 0); + return m_pArray[0]; + } + T& back() + { + VMA_HEAVY_ASSERT(m_Count > 0); + return m_pArray[m_Count - 1]; + } + const T& back() const + { + VMA_HEAVY_ASSERT(m_Count > 0); + return m_pArray[m_Count - 1]; + } + + void reserve(size_t newCapacity, bool freeMemory = false) + { + newCapacity = VMA_MAX(newCapacity, m_Count); + + if((newCapacity < m_Capacity) && !freeMemory) + { + newCapacity = m_Capacity; + } + + if(newCapacity != m_Capacity) + { + T* const newArray = newCapacity ? VmaAllocateArray(m_Allocator, newCapacity) : VMA_NULL; + if(m_Count != 0) + { + memcpy(newArray, m_pArray, m_Count * sizeof(T)); + } + VmaFree(m_Allocator.m_pCallbacks, m_pArray); + m_Capacity = newCapacity; + m_pArray = newArray; + } + } + + void resize(size_t newCount, bool freeMemory = false) + { + size_t newCapacity = m_Capacity; + if(newCount > m_Capacity) + { + newCapacity = VMA_MAX(newCount, VMA_MAX(m_Capacity * 3 / 2, (size_t)8)); + } + else if(freeMemory) + { + newCapacity = newCount; + } + + if(newCapacity != m_Capacity) + { + T* const newArray = newCapacity ? VmaAllocateArray(m_Allocator.m_pCallbacks, newCapacity) : VMA_NULL; + const size_t elementsToCopy = VMA_MIN(m_Count, newCount); + if(elementsToCopy != 0) + { + memcpy(newArray, m_pArray, elementsToCopy * sizeof(T)); + } + VmaFree(m_Allocator.m_pCallbacks, m_pArray); + m_Capacity = newCapacity; + m_pArray = newArray; + } + + m_Count = newCount; + } + + void clear(bool freeMemory = false) + { + resize(0, freeMemory); + } + + void insert(size_t index, const T& src) + { + VMA_HEAVY_ASSERT(index <= m_Count); + const size_t oldCount = size(); + resize(oldCount + 1); + if(index < oldCount) + { + memmove(m_pArray + (index + 1), m_pArray + index, (oldCount - index) * sizeof(T)); + } + m_pArray[index] = src; + } + + void remove(size_t index) + { + VMA_HEAVY_ASSERT(index < m_Count); + const size_t oldCount = size(); + if(index < oldCount - 1) + { + memmove(m_pArray + index, m_pArray + (index + 1), (oldCount - index - 1) * sizeof(T)); + } + resize(oldCount - 1); + } + + void push_back(const T& src) + { + const size_t newIndex = size(); + resize(newIndex + 1); + m_pArray[newIndex] = src; + } + + void pop_back() + { + VMA_HEAVY_ASSERT(m_Count > 0); + resize(size() - 1); + } + + void push_front(const T& src) + { + insert(0, src); + } + + void pop_front() + { + VMA_HEAVY_ASSERT(m_Count > 0); + remove(0); + } + + typedef T* iterator; + + iterator begin() { return m_pArray; } + iterator end() { return m_pArray + m_Count; } + +private: + AllocatorT m_Allocator; + T* m_pArray; + size_t m_Count; + size_t m_Capacity; +}; + +template +static void VmaVectorInsert(VmaVector& vec, size_t index, const T& item) +{ + vec.insert(index, item); +} + +template +static void VmaVectorRemove(VmaVector& vec, size_t index) +{ + vec.remove(index); +} + +#endif // #if VMA_USE_STL_VECTOR + +template +size_t VmaVectorInsertSorted(VectorT& vector, const typename VectorT::value_type& value) +{ + const size_t indexToInsert = VmaBinaryFindFirstNotLess( + vector.data(), + vector.data() + vector.size(), + value, + CmpLess()) - vector.data(); + VmaVectorInsert(vector, indexToInsert, value); + return indexToInsert; +} + +template +bool VmaVectorRemoveSorted(VectorT& vector, const typename VectorT::value_type& value) +{ + CmpLess comparator; + typename VectorT::iterator it = VmaBinaryFindFirstNotLess( + vector.begin(), + vector.end(), + value, + comparator); + if((it != vector.end()) && !comparator(*it, value) && !comparator(value, *it)) + { + size_t indexToRemove = it - vector.begin(); + VmaVectorRemove(vector, indexToRemove); + return true; + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////// +// class VmaSmallVector + +/* +This is a vector (a variable-sized array), optimized for the case when the array is small. + +It contains some number of elements in-place, which allows it to avoid heap allocation +when the actual number of elements is below that threshold. This allows normal "small" +cases to be fast without losing generality for large inputs. +*/ + +template +class VmaSmallVector +{ +public: + typedef T value_type; + + VmaSmallVector(const AllocatorT& allocator) : + m_Count(0), + m_DynamicArray(allocator) + { + } + VmaSmallVector(size_t count, const AllocatorT& allocator) : + m_Count(count), + m_DynamicArray(count > N ? count : 0, allocator) + { + } + template + VmaSmallVector(const VmaSmallVector& src) = delete; + template + VmaSmallVector& operator=(const VmaSmallVector& rhs) = delete; + + bool empty() const { return m_Count == 0; } + size_t size() const { return m_Count; } + T* data() { return m_Count > N ? m_DynamicArray.data() : m_StaticArray; } + const T* data() const { return m_Count > N ? m_DynamicArray.data() : m_StaticArray; } + + T& operator[](size_t index) + { + VMA_HEAVY_ASSERT(index < m_Count); + return data()[index]; + } + const T& operator[](size_t index) const + { + VMA_HEAVY_ASSERT(index < m_Count); + return data()[index]; + } + + T& front() + { + VMA_HEAVY_ASSERT(m_Count > 0); + return data()[0]; + } + const T& front() const + { + VMA_HEAVY_ASSERT(m_Count > 0); + return data()[0]; + } + T& back() + { + VMA_HEAVY_ASSERT(m_Count > 0); + return data()[m_Count - 1]; + } + const T& back() const + { + VMA_HEAVY_ASSERT(m_Count > 0); + return data()[m_Count - 1]; + } + + void resize(size_t newCount, bool freeMemory = false) + { + if(newCount > N && m_Count > N) + { + // Any direction, staying in m_DynamicArray + m_DynamicArray.resize(newCount, freeMemory); + } + else if(newCount > N && m_Count <= N) + { + // Growing, moving from m_StaticArray to m_DynamicArray + m_DynamicArray.resize(newCount, freeMemory); + if(m_Count > 0) + { + memcpy(m_DynamicArray.data(), m_StaticArray, m_Count * sizeof(T)); + } + } + else if(newCount <= N && m_Count > N) + { + // Shrinking, moving from m_DynamicArray to m_StaticArray + if(newCount > 0) + { + memcpy(m_StaticArray, m_DynamicArray.data(), newCount * sizeof(T)); + } + m_DynamicArray.resize(0, freeMemory); + } + else + { + // Any direction, staying in m_StaticArray - nothing to do here + } + m_Count = newCount; + } + + void clear(bool freeMemory = false) + { + m_DynamicArray.clear(freeMemory); + m_Count = 0; + } + + void insert(size_t index, const T& src) + { + VMA_HEAVY_ASSERT(index <= m_Count); + const size_t oldCount = size(); + resize(oldCount + 1); + T* const dataPtr = data(); + if(index < oldCount) + { + // I know, this could be more optimal for case where memmove can be memcpy directly from m_StaticArray to m_DynamicArray. + memmove(dataPtr + (index + 1), dataPtr + index, (oldCount - index) * sizeof(T)); + } + dataPtr[index] = src; + } + + void remove(size_t index) + { + VMA_HEAVY_ASSERT(index < m_Count); + const size_t oldCount = size(); + if(index < oldCount - 1) + { + // I know, this could be more optimal for case where memmove can be memcpy directly from m_DynamicArray to m_StaticArray. + T* const dataPtr = data(); + memmove(dataPtr + index, dataPtr + (index + 1), (oldCount - index - 1) * sizeof(T)); + } + resize(oldCount - 1); + } + + void push_back(const T& src) + { + const size_t newIndex = size(); + resize(newIndex + 1); + data()[newIndex] = src; + } + + void pop_back() + { + VMA_HEAVY_ASSERT(m_Count > 0); + resize(size() - 1); + } + + void push_front(const T& src) + { + insert(0, src); + } + + void pop_front() + { + VMA_HEAVY_ASSERT(m_Count > 0); + remove(0); + } + + typedef T* iterator; + + iterator begin() { return data(); } + iterator end() { return data() + m_Count; } + +private: + size_t m_Count; + T m_StaticArray[N]; // Used when m_Size <= N + VmaVector m_DynamicArray; // Used when m_Size > N +}; + +//////////////////////////////////////////////////////////////////////////////// +// class VmaPoolAllocator + +/* +Allocator for objects of type T using a list of arrays (pools) to speed up +allocation. Number of elements that can be allocated is not bounded because +allocator can create multiple blocks. +*/ +template +class VmaPoolAllocator +{ + VMA_CLASS_NO_COPY(VmaPoolAllocator) +public: + VmaPoolAllocator(const VkAllocationCallbacks* pAllocationCallbacks, uint32_t firstBlockCapacity); + ~VmaPoolAllocator(); + template T* Alloc(Types... args); + void Free(T* ptr); + +private: + union Item + { + uint32_t NextFreeIndex; + alignas(T) char Value[sizeof(T)]; + }; + + struct ItemBlock + { + Item* pItems; + uint32_t Capacity; + uint32_t FirstFreeIndex; + }; + + const VkAllocationCallbacks* m_pAllocationCallbacks; + const uint32_t m_FirstBlockCapacity; + VmaVector< ItemBlock, VmaStlAllocator > m_ItemBlocks; + + ItemBlock& CreateNewBlock(); +}; + +template +VmaPoolAllocator::VmaPoolAllocator(const VkAllocationCallbacks* pAllocationCallbacks, uint32_t firstBlockCapacity) : + m_pAllocationCallbacks(pAllocationCallbacks), + m_FirstBlockCapacity(firstBlockCapacity), + m_ItemBlocks(VmaStlAllocator(pAllocationCallbacks)) +{ + VMA_ASSERT(m_FirstBlockCapacity > 1); +} + +template +VmaPoolAllocator::~VmaPoolAllocator() +{ + for(size_t i = m_ItemBlocks.size(); i--; ) + vma_delete_array(m_pAllocationCallbacks, m_ItemBlocks[i].pItems, m_ItemBlocks[i].Capacity); + m_ItemBlocks.clear(); +} + +template +template T* VmaPoolAllocator::Alloc(Types... args) +{ + for(size_t i = m_ItemBlocks.size(); i--; ) + { + ItemBlock& block = m_ItemBlocks[i]; + // This block has some free items: Use first one. + if(block.FirstFreeIndex != UINT32_MAX) + { + Item* const pItem = &block.pItems[block.FirstFreeIndex]; + block.FirstFreeIndex = pItem->NextFreeIndex; + T* result = (T*)&pItem->Value; + new(result)T(std::forward(args)...); // Explicit constructor call. + return result; + } + } + + // No block has free item: Create new one and use it. + ItemBlock& newBlock = CreateNewBlock(); + Item* const pItem = &newBlock.pItems[0]; + newBlock.FirstFreeIndex = pItem->NextFreeIndex; + T* result = (T*)&pItem->Value; + new(result)T(std::forward(args)...); // Explicit constructor call. + return result; +} + +template +void VmaPoolAllocator::Free(T* ptr) +{ + // Search all memory blocks to find ptr. + for(size_t i = m_ItemBlocks.size(); i--; ) + { + ItemBlock& block = m_ItemBlocks[i]; + + // Casting to union. + Item* pItemPtr; + memcpy(&pItemPtr, &ptr, sizeof(pItemPtr)); + + // Check if pItemPtr is in address range of this block. + if((pItemPtr >= block.pItems) && (pItemPtr < block.pItems + block.Capacity)) + { + ptr->~T(); // Explicit destructor call. + const uint32_t index = static_cast(pItemPtr - block.pItems); + pItemPtr->NextFreeIndex = block.FirstFreeIndex; + block.FirstFreeIndex = index; + return; + } + } + VMA_ASSERT(0 && "Pointer doesn't belong to this memory pool."); +} + +template +typename VmaPoolAllocator::ItemBlock& VmaPoolAllocator::CreateNewBlock() +{ + const uint32_t newBlockCapacity = m_ItemBlocks.empty() ? + m_FirstBlockCapacity : m_ItemBlocks.back().Capacity * 3 / 2; + + const ItemBlock newBlock = { + vma_new_array(m_pAllocationCallbacks, Item, newBlockCapacity), + newBlockCapacity, + 0 }; + + m_ItemBlocks.push_back(newBlock); + + // Setup singly-linked list of all free items in this block. + for(uint32_t i = 0; i < newBlockCapacity - 1; ++i) + newBlock.pItems[i].NextFreeIndex = i + 1; + newBlock.pItems[newBlockCapacity - 1].NextFreeIndex = UINT32_MAX; + return m_ItemBlocks.back(); +} + +//////////////////////////////////////////////////////////////////////////////// +// class VmaRawList, VmaList + +#if VMA_USE_STL_LIST + +#define VmaList std::list + +#else // #if VMA_USE_STL_LIST + +template +struct VmaListItem +{ + VmaListItem* pPrev; + VmaListItem* pNext; + T Value; +}; + +// Doubly linked list. +template +class VmaRawList +{ + VMA_CLASS_NO_COPY(VmaRawList) +public: + typedef VmaListItem ItemType; + + VmaRawList(const VkAllocationCallbacks* pAllocationCallbacks); + ~VmaRawList(); + void Clear(); + + size_t GetCount() const { return m_Count; } + bool IsEmpty() const { return m_Count == 0; } + + ItemType* Front() { return m_pFront; } + const ItemType* Front() const { return m_pFront; } + ItemType* Back() { return m_pBack; } + const ItemType* Back() const { return m_pBack; } + + ItemType* PushBack(); + ItemType* PushFront(); + ItemType* PushBack(const T& value); + ItemType* PushFront(const T& value); + void PopBack(); + void PopFront(); + + // Item can be null - it means PushBack. + ItemType* InsertBefore(ItemType* pItem); + // Item can be null - it means PushFront. + ItemType* InsertAfter(ItemType* pItem); + + ItemType* InsertBefore(ItemType* pItem, const T& value); + ItemType* InsertAfter(ItemType* pItem, const T& value); + + void Remove(ItemType* pItem); + +private: + const VkAllocationCallbacks* const m_pAllocationCallbacks; + VmaPoolAllocator m_ItemAllocator; + ItemType* m_pFront; + ItemType* m_pBack; + size_t m_Count; +}; + +template +VmaRawList::VmaRawList(const VkAllocationCallbacks* pAllocationCallbacks) : + m_pAllocationCallbacks(pAllocationCallbacks), + m_ItemAllocator(pAllocationCallbacks, 128), + m_pFront(VMA_NULL), + m_pBack(VMA_NULL), + m_Count(0) +{ +} + +template +VmaRawList::~VmaRawList() +{ + // Intentionally not calling Clear, because that would be unnecessary + // computations to return all items to m_ItemAllocator as free. +} + +template +void VmaRawList::Clear() +{ + if(IsEmpty() == false) + { + ItemType* pItem = m_pBack; + while(pItem != VMA_NULL) + { + ItemType* const pPrevItem = pItem->pPrev; + m_ItemAllocator.Free(pItem); + pItem = pPrevItem; + } + m_pFront = VMA_NULL; + m_pBack = VMA_NULL; + m_Count = 0; + } +} + +template +VmaListItem* VmaRawList::PushBack() +{ + ItemType* const pNewItem = m_ItemAllocator.Alloc(); + pNewItem->pNext = VMA_NULL; + if(IsEmpty()) + { + pNewItem->pPrev = VMA_NULL; + m_pFront = pNewItem; + m_pBack = pNewItem; + m_Count = 1; + } + else + { + pNewItem->pPrev = m_pBack; + m_pBack->pNext = pNewItem; + m_pBack = pNewItem; + ++m_Count; + } + return pNewItem; +} + +template +VmaListItem* VmaRawList::PushFront() +{ + ItemType* const pNewItem = m_ItemAllocator.Alloc(); + pNewItem->pPrev = VMA_NULL; + if(IsEmpty()) + { + pNewItem->pNext = VMA_NULL; + m_pFront = pNewItem; + m_pBack = pNewItem; + m_Count = 1; + } + else + { + pNewItem->pNext = m_pFront; + m_pFront->pPrev = pNewItem; + m_pFront = pNewItem; + ++m_Count; + } + return pNewItem; +} + +template +VmaListItem* VmaRawList::PushBack(const T& value) +{ + ItemType* const pNewItem = PushBack(); + pNewItem->Value = value; + return pNewItem; +} + +template +VmaListItem* VmaRawList::PushFront(const T& value) +{ + ItemType* const pNewItem = PushFront(); + pNewItem->Value = value; + return pNewItem; +} + +template +void VmaRawList::PopBack() +{ + VMA_HEAVY_ASSERT(m_Count > 0); + ItemType* const pBackItem = m_pBack; + ItemType* const pPrevItem = pBackItem->pPrev; + if(pPrevItem != VMA_NULL) + { + pPrevItem->pNext = VMA_NULL; + } + m_pBack = pPrevItem; + m_ItemAllocator.Free(pBackItem); + --m_Count; +} + +template +void VmaRawList::PopFront() +{ + VMA_HEAVY_ASSERT(m_Count > 0); + ItemType* const pFrontItem = m_pFront; + ItemType* const pNextItem = pFrontItem->pNext; + if(pNextItem != VMA_NULL) + { + pNextItem->pPrev = VMA_NULL; + } + m_pFront = pNextItem; + m_ItemAllocator.Free(pFrontItem); + --m_Count; +} + +template +void VmaRawList::Remove(ItemType* pItem) +{ + VMA_HEAVY_ASSERT(pItem != VMA_NULL); + VMA_HEAVY_ASSERT(m_Count > 0); + + if(pItem->pPrev != VMA_NULL) + { + pItem->pPrev->pNext = pItem->pNext; + } + else + { + VMA_HEAVY_ASSERT(m_pFront == pItem); + m_pFront = pItem->pNext; + } + + if(pItem->pNext != VMA_NULL) + { + pItem->pNext->pPrev = pItem->pPrev; + } + else + { + VMA_HEAVY_ASSERT(m_pBack == pItem); + m_pBack = pItem->pPrev; + } + + m_ItemAllocator.Free(pItem); + --m_Count; +} + +template +VmaListItem* VmaRawList::InsertBefore(ItemType* pItem) +{ + if(pItem != VMA_NULL) + { + ItemType* const prevItem = pItem->pPrev; + ItemType* const newItem = m_ItemAllocator.Alloc(); + newItem->pPrev = prevItem; + newItem->pNext = pItem; + pItem->pPrev = newItem; + if(prevItem != VMA_NULL) + { + prevItem->pNext = newItem; + } + else + { + VMA_HEAVY_ASSERT(m_pFront == pItem); + m_pFront = newItem; + } + ++m_Count; + return newItem; + } + else + return PushBack(); +} + +template +VmaListItem* VmaRawList::InsertAfter(ItemType* pItem) +{ + if(pItem != VMA_NULL) + { + ItemType* const nextItem = pItem->pNext; + ItemType* const newItem = m_ItemAllocator.Alloc(); + newItem->pNext = nextItem; + newItem->pPrev = pItem; + pItem->pNext = newItem; + if(nextItem != VMA_NULL) + { + nextItem->pPrev = newItem; + } + else + { + VMA_HEAVY_ASSERT(m_pBack == pItem); + m_pBack = newItem; + } + ++m_Count; + return newItem; + } + else + return PushFront(); +} + +template +VmaListItem* VmaRawList::InsertBefore(ItemType* pItem, const T& value) +{ + ItemType* const newItem = InsertBefore(pItem); + newItem->Value = value; + return newItem; +} + +template +VmaListItem* VmaRawList::InsertAfter(ItemType* pItem, const T& value) +{ + ItemType* const newItem = InsertAfter(pItem); + newItem->Value = value; + return newItem; +} + +template +class VmaList +{ + VMA_CLASS_NO_COPY(VmaList) +public: + class iterator + { + public: + iterator() : + m_pList(VMA_NULL), + m_pItem(VMA_NULL) + { + } + + T& operator*() const + { + VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); + return m_pItem->Value; + } + T* operator->() const + { + VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); + return &m_pItem->Value; + } + + iterator& operator++() + { + VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); + m_pItem = m_pItem->pNext; + return *this; + } + iterator& operator--() + { + if(m_pItem != VMA_NULL) + { + m_pItem = m_pItem->pPrev; + } + else + { + VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); + m_pItem = m_pList->Back(); + } + return *this; + } + + iterator operator++(int) + { + iterator result = *this; + ++*this; + return result; + } + iterator operator--(int) + { + iterator result = *this; + --*this; + return result; + } + + bool operator==(const iterator& rhs) const + { + VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); + return m_pItem == rhs.m_pItem; + } + bool operator!=(const iterator& rhs) const + { + VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); + return m_pItem != rhs.m_pItem; + } + + private: + VmaRawList* m_pList; + VmaListItem* m_pItem; + + iterator(VmaRawList* pList, VmaListItem* pItem) : + m_pList(pList), + m_pItem(pItem) + { + } + + friend class VmaList; + }; + + class const_iterator + { + public: + const_iterator() : + m_pList(VMA_NULL), + m_pItem(VMA_NULL) + { + } + + const_iterator(const iterator& src) : + m_pList(src.m_pList), + m_pItem(src.m_pItem) + { + } + + const T& operator*() const + { + VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); + return m_pItem->Value; + } + const T* operator->() const + { + VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); + return &m_pItem->Value; + } + + const_iterator& operator++() + { + VMA_HEAVY_ASSERT(m_pItem != VMA_NULL); + m_pItem = m_pItem->pNext; + return *this; + } + const_iterator& operator--() + { + if(m_pItem != VMA_NULL) + { + m_pItem = m_pItem->pPrev; + } + else + { + VMA_HEAVY_ASSERT(!m_pList->IsEmpty()); + m_pItem = m_pList->Back(); + } + return *this; + } + + const_iterator operator++(int) + { + const_iterator result = *this; + ++*this; + return result; + } + const_iterator operator--(int) + { + const_iterator result = *this; + --*this; + return result; + } + + bool operator==(const const_iterator& rhs) const + { + VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); + return m_pItem == rhs.m_pItem; + } + bool operator!=(const const_iterator& rhs) const + { + VMA_HEAVY_ASSERT(m_pList == rhs.m_pList); + return m_pItem != rhs.m_pItem; + } + + private: + const_iterator(const VmaRawList* pList, const VmaListItem* pItem) : + m_pList(pList), + m_pItem(pItem) + { + } + + const VmaRawList* m_pList; + const VmaListItem* m_pItem; + + friend class VmaList; + }; + + VmaList(const AllocatorT& allocator) : m_RawList(allocator.m_pCallbacks) { } + + bool empty() const { return m_RawList.IsEmpty(); } + size_t size() const { return m_RawList.GetCount(); } + + iterator begin() { return iterator(&m_RawList, m_RawList.Front()); } + iterator end() { return iterator(&m_RawList, VMA_NULL); } + + const_iterator cbegin() const { return const_iterator(&m_RawList, m_RawList.Front()); } + const_iterator cend() const { return const_iterator(&m_RawList, VMA_NULL); } + + void clear() { m_RawList.Clear(); } + void push_back(const T& value) { m_RawList.PushBack(value); } + void erase(iterator it) { m_RawList.Remove(it.m_pItem); } + iterator insert(iterator it, const T& value) { return iterator(&m_RawList, m_RawList.InsertBefore(it.m_pItem, value)); } + +private: + VmaRawList m_RawList; +}; + +#endif // #if VMA_USE_STL_LIST + +//////////////////////////////////////////////////////////////////////////////// +// class VmaMap + +// Unused in this version. +#if 0 + +#if VMA_USE_STL_UNORDERED_MAP + +#define VmaPair std::pair + +#define VMA_MAP_TYPE(KeyT, ValueT) \ + std::unordered_map< KeyT, ValueT, std::hash, std::equal_to, VmaStlAllocator< std::pair > > + +#else // #if VMA_USE_STL_UNORDERED_MAP + +template +struct VmaPair +{ + T1 first; + T2 second; + + VmaPair() : first(), second() { } + VmaPair(const T1& firstSrc, const T2& secondSrc) : first(firstSrc), second(secondSrc) { } +}; + +/* Class compatible with subset of interface of std::unordered_map. +KeyT, ValueT must be POD because they will be stored in VmaVector. +*/ +template +class VmaMap +{ +public: + typedef VmaPair PairType; + typedef PairType* iterator; + + VmaMap(const VmaStlAllocator& allocator) : m_Vector(allocator) { } + + iterator begin() { return m_Vector.begin(); } + iterator end() { return m_Vector.end(); } + + void insert(const PairType& pair); + iterator find(const KeyT& key); + void erase(iterator it); + +private: + VmaVector< PairType, VmaStlAllocator > m_Vector; +}; + +#define VMA_MAP_TYPE(KeyT, ValueT) VmaMap + +template +struct VmaPairFirstLess +{ + bool operator()(const VmaPair& lhs, const VmaPair& rhs) const + { + return lhs.first < rhs.first; + } + bool operator()(const VmaPair& lhs, const FirstT& rhsFirst) const + { + return lhs.first < rhsFirst; + } +}; + +template +void VmaMap::insert(const PairType& pair) +{ + const size_t indexToInsert = VmaBinaryFindFirstNotLess( + m_Vector.data(), + m_Vector.data() + m_Vector.size(), + pair, + VmaPairFirstLess()) - m_Vector.data(); + VmaVectorInsert(m_Vector, indexToInsert, pair); +} + +template +VmaPair* VmaMap::find(const KeyT& key) +{ + PairType* it = VmaBinaryFindFirstNotLess( + m_Vector.data(), + m_Vector.data() + m_Vector.size(), + key, + VmaPairFirstLess()); + if((it != m_Vector.end()) && (it->first == key)) + { + return it; + } + else + { + return m_Vector.end(); + } +} + +template +void VmaMap::erase(iterator it) +{ + VmaVectorRemove(m_Vector, it - m_Vector.begin()); +} + +#endif // #if VMA_USE_STL_UNORDERED_MAP + +#endif // #if 0 + +//////////////////////////////////////////////////////////////////////////////// + +class VmaDeviceMemoryBlock; + +enum VMA_CACHE_OPERATION { VMA_CACHE_FLUSH, VMA_CACHE_INVALIDATE }; + +struct VmaAllocation_T +{ +private: + static const uint8_t MAP_COUNT_FLAG_PERSISTENT_MAP = 0x80; + + enum FLAGS + { + FLAG_USER_DATA_STRING = 0x01, + }; + +public: + enum ALLOCATION_TYPE + { + ALLOCATION_TYPE_NONE, + ALLOCATION_TYPE_BLOCK, + ALLOCATION_TYPE_DEDICATED, + }; + + /* + This struct is allocated using VmaPoolAllocator. + */ + + VmaAllocation_T(uint32_t currentFrameIndex, bool userDataString) : + m_Alignment{1}, + m_Size{0}, + m_pUserData{VMA_NULL}, + m_LastUseFrameIndex{currentFrameIndex}, + m_MemoryTypeIndex{0}, + m_Type{(uint8_t)ALLOCATION_TYPE_NONE}, + m_SuballocationType{(uint8_t)VMA_SUBALLOCATION_TYPE_UNKNOWN}, + m_MapCount{0}, + m_Flags{userDataString ? (uint8_t)FLAG_USER_DATA_STRING : (uint8_t)0} + { +#if VMA_STATS_STRING_ENABLED + m_CreationFrameIndex = currentFrameIndex; + m_BufferImageUsage = 0; +#endif + } + + ~VmaAllocation_T() + { + VMA_ASSERT((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) == 0 && "Allocation was not unmapped before destruction."); + + // Check if owned string was freed. + VMA_ASSERT(m_pUserData == VMA_NULL); + } + + void InitBlockAllocation( + VmaDeviceMemoryBlock* block, + VkDeviceSize offset, + VkDeviceSize alignment, + VkDeviceSize size, + uint32_t memoryTypeIndex, + VmaSuballocationType suballocationType, + bool mapped, + bool canBecomeLost) + { + VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); + VMA_ASSERT(block != VMA_NULL); + m_Type = (uint8_t)ALLOCATION_TYPE_BLOCK; + m_Alignment = alignment; + m_Size = size; + m_MemoryTypeIndex = memoryTypeIndex; + m_MapCount = mapped ? MAP_COUNT_FLAG_PERSISTENT_MAP : 0; + m_SuballocationType = (uint8_t)suballocationType; + m_BlockAllocation.m_Block = block; + m_BlockAllocation.m_Offset = offset; + m_BlockAllocation.m_CanBecomeLost = canBecomeLost; + } + + void InitLost() + { + VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); + VMA_ASSERT(m_LastUseFrameIndex.load() == VMA_FRAME_INDEX_LOST); + m_Type = (uint8_t)ALLOCATION_TYPE_BLOCK; + m_MemoryTypeIndex = 0; + m_BlockAllocation.m_Block = VMA_NULL; + m_BlockAllocation.m_Offset = 0; + m_BlockAllocation.m_CanBecomeLost = true; + } + + void ChangeBlockAllocation( + VmaAllocator hAllocator, + VmaDeviceMemoryBlock* block, + VkDeviceSize offset); + + void ChangeOffset(VkDeviceSize newOffset); + + // pMappedData not null means allocation is created with MAPPED flag. + void InitDedicatedAllocation( + uint32_t memoryTypeIndex, + VkDeviceMemory hMemory, + VmaSuballocationType suballocationType, + void* pMappedData, + VkDeviceSize size) + { + VMA_ASSERT(m_Type == ALLOCATION_TYPE_NONE); + VMA_ASSERT(hMemory != VK_NULL_HANDLE); + m_Type = (uint8_t)ALLOCATION_TYPE_DEDICATED; + m_Alignment = 0; + m_Size = size; + m_MemoryTypeIndex = memoryTypeIndex; + m_SuballocationType = (uint8_t)suballocationType; + m_MapCount = (pMappedData != VMA_NULL) ? MAP_COUNT_FLAG_PERSISTENT_MAP : 0; + m_DedicatedAllocation.m_hMemory = hMemory; + m_DedicatedAllocation.m_pMappedData = pMappedData; + } + + ALLOCATION_TYPE GetType() const { return (ALLOCATION_TYPE)m_Type; } + VkDeviceSize GetAlignment() const { return m_Alignment; } + VkDeviceSize GetSize() const { return m_Size; } + bool IsUserDataString() const { return (m_Flags & FLAG_USER_DATA_STRING) != 0; } + void* GetUserData() const { return m_pUserData; } + void SetUserData(VmaAllocator hAllocator, void* pUserData); + VmaSuballocationType GetSuballocationType() const { return (VmaSuballocationType)m_SuballocationType; } + + VmaDeviceMemoryBlock* GetBlock() const + { + VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); + return m_BlockAllocation.m_Block; + } + VkDeviceSize GetOffset() const; + VkDeviceMemory GetMemory() const; + uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } + bool IsPersistentMap() const { return (m_MapCount & MAP_COUNT_FLAG_PERSISTENT_MAP) != 0; } + void* GetMappedData() const; + bool CanBecomeLost() const; + + uint32_t GetLastUseFrameIndex() const + { + return m_LastUseFrameIndex.load(); + } + bool CompareExchangeLastUseFrameIndex(uint32_t& expected, uint32_t desired) + { + return m_LastUseFrameIndex.compare_exchange_weak(expected, desired); + } + /* + - If hAllocation.LastUseFrameIndex + frameInUseCount < allocator.CurrentFrameIndex, + makes it lost by setting LastUseFrameIndex = VMA_FRAME_INDEX_LOST and returns true. + - Else, returns false. + + If hAllocation is already lost, assert - you should not call it then. + If hAllocation was not created with CAN_BECOME_LOST_BIT, assert. + */ + bool MakeLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + + void DedicatedAllocCalcStatsInfo(VmaStatInfo& outInfo) + { + VMA_ASSERT(m_Type == ALLOCATION_TYPE_DEDICATED); + outInfo.blockCount = 1; + outInfo.allocationCount = 1; + outInfo.unusedRangeCount = 0; + outInfo.usedBytes = m_Size; + outInfo.unusedBytes = 0; + outInfo.allocationSizeMin = outInfo.allocationSizeMax = m_Size; + outInfo.unusedRangeSizeMin = UINT64_MAX; + outInfo.unusedRangeSizeMax = 0; + } + + void BlockAllocMap(); + void BlockAllocUnmap(); + VkResult DedicatedAllocMap(VmaAllocator hAllocator, void** ppData); + void DedicatedAllocUnmap(VmaAllocator hAllocator); + +#if VMA_STATS_STRING_ENABLED + uint32_t GetCreationFrameIndex() const { return m_CreationFrameIndex; } + uint32_t GetBufferImageUsage() const { return m_BufferImageUsage; } + + void InitBufferImageUsage(uint32_t bufferImageUsage) + { + VMA_ASSERT(m_BufferImageUsage == 0); + m_BufferImageUsage = bufferImageUsage; + } + + void PrintParameters(class VmaJsonWriter& json) const; +#endif + +private: + VkDeviceSize m_Alignment; + VkDeviceSize m_Size; + void* m_pUserData; + VMA_ATOMIC_UINT32 m_LastUseFrameIndex; + uint32_t m_MemoryTypeIndex; + uint8_t m_Type; // ALLOCATION_TYPE + uint8_t m_SuballocationType; // VmaSuballocationType + // Bit 0x80 is set when allocation was created with VMA_ALLOCATION_CREATE_MAPPED_BIT. + // Bits with mask 0x7F are reference counter for vmaMapMemory()/vmaUnmapMemory(). + uint8_t m_MapCount; + uint8_t m_Flags; // enum FLAGS + + // Allocation out of VmaDeviceMemoryBlock. + struct BlockAllocation + { + VmaDeviceMemoryBlock* m_Block; + VkDeviceSize m_Offset; + bool m_CanBecomeLost; + }; + + // Allocation for an object that has its own private VkDeviceMemory. + struct DedicatedAllocation + { + VkDeviceMemory m_hMemory; + void* m_pMappedData; // Not null means memory is mapped. + }; + + union + { + // Allocation out of VmaDeviceMemoryBlock. + BlockAllocation m_BlockAllocation; + // Allocation for an object that has its own private VkDeviceMemory. + DedicatedAllocation m_DedicatedAllocation; + }; + +#if VMA_STATS_STRING_ENABLED + uint32_t m_CreationFrameIndex; + uint32_t m_BufferImageUsage; // 0 if unknown. +#endif + + void FreeUserDataString(VmaAllocator hAllocator); +}; + +/* +Represents a region of VmaDeviceMemoryBlock that is either assigned and returned as +allocated memory block or free. +*/ +struct VmaSuballocation +{ + VkDeviceSize offset; + VkDeviceSize size; + VmaAllocation hAllocation; + VmaSuballocationType type; +}; + +// Comparator for offsets. +struct VmaSuballocationOffsetLess +{ + bool operator()(const VmaSuballocation& lhs, const VmaSuballocation& rhs) const + { + return lhs.offset < rhs.offset; + } +}; +struct VmaSuballocationOffsetGreater +{ + bool operator()(const VmaSuballocation& lhs, const VmaSuballocation& rhs) const + { + return lhs.offset > rhs.offset; + } +}; + +typedef VmaList< VmaSuballocation, VmaStlAllocator > VmaSuballocationList; + +// Cost of one additional allocation lost, as equivalent in bytes. +static const VkDeviceSize VMA_LOST_ALLOCATION_COST = 1048576; + +enum class VmaAllocationRequestType +{ + Normal, + // Used by "Linear" algorithm. + UpperAddress, + EndOf1st, + EndOf2nd, +}; + +/* +Parameters of planned allocation inside a VmaDeviceMemoryBlock. + +If canMakeOtherLost was false: +- item points to a FREE suballocation. +- itemsToMakeLostCount is 0. + +If canMakeOtherLost was true: +- item points to first of sequence of suballocations, which are either FREE, + or point to VmaAllocations that can become lost. +- itemsToMakeLostCount is the number of VmaAllocations that need to be made lost for + the requested allocation to succeed. +*/ +struct VmaAllocationRequest +{ + VkDeviceSize offset; + VkDeviceSize sumFreeSize; // Sum size of free items that overlap with proposed allocation. + VkDeviceSize sumItemSize; // Sum size of items to make lost that overlap with proposed allocation. + VmaSuballocationList::iterator item; + size_t itemsToMakeLostCount; + void* customData; + VmaAllocationRequestType type; + + VkDeviceSize CalcCost() const + { + return sumItemSize + itemsToMakeLostCount * VMA_LOST_ALLOCATION_COST; + } +}; + +/* +Data structure used for bookkeeping of allocations and unused ranges of memory +in a single VkDeviceMemory block. +*/ +class VmaBlockMetadata +{ +public: + VmaBlockMetadata(VmaAllocator hAllocator); + virtual ~VmaBlockMetadata() { } + virtual void Init(VkDeviceSize size) { m_Size = size; } + + // Validates all data structures inside this object. If not valid, returns false. + virtual bool Validate() const = 0; + VkDeviceSize GetSize() const { return m_Size; } + virtual size_t GetAllocationCount() const = 0; + virtual VkDeviceSize GetSumFreeSize() const = 0; + virtual VkDeviceSize GetUnusedRangeSizeMax() const = 0; + // Returns true if this block is empty - contains only single free suballocation. + virtual bool IsEmpty() const = 0; + + virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const = 0; + // Shouldn't modify blockCount. + virtual void AddPoolStats(VmaPoolStats& inoutStats) const = 0; + +#if VMA_STATS_STRING_ENABLED + virtual void PrintDetailedMap(class VmaJsonWriter& json) const = 0; +#endif + + // Tries to find a place for suballocation with given parameters inside this block. + // If succeeded, fills pAllocationRequest and returns true. + // If failed, returns false. + virtual bool CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + // Always one of VMA_ALLOCATION_CREATE_STRATEGY_* or VMA_ALLOCATION_INTERNAL_STRATEGY_* flags. + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) = 0; + + virtual bool MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest) = 0; + + virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) = 0; + + virtual VkResult CheckCorruption(const void* pBlockData) = 0; + + // Makes actual allocation based on request. Request must already be checked and valid. + virtual void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation) = 0; + + // Frees suballocation assigned to given memory region. + virtual void Free(const VmaAllocation allocation) = 0; + virtual void FreeAtOffset(VkDeviceSize offset) = 0; + +protected: + const VkAllocationCallbacks* GetAllocationCallbacks() const { return m_pAllocationCallbacks; } + +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap_Begin(class VmaJsonWriter& json, + VkDeviceSize unusedBytes, + size_t allocationCount, + size_t unusedRangeCount) const; + void PrintDetailedMap_Allocation(class VmaJsonWriter& json, + VkDeviceSize offset, + VmaAllocation hAllocation) const; + void PrintDetailedMap_UnusedRange(class VmaJsonWriter& json, + VkDeviceSize offset, + VkDeviceSize size) const; + void PrintDetailedMap_End(class VmaJsonWriter& json) const; +#endif + +private: + VkDeviceSize m_Size; + const VkAllocationCallbacks* m_pAllocationCallbacks; +}; + +#define VMA_VALIDATE(cond) do { if(!(cond)) { \ + VMA_ASSERT(0 && "Validation failed: " #cond); \ + return false; \ + } } while(false) + +class VmaBlockMetadata_Generic : public VmaBlockMetadata +{ + VMA_CLASS_NO_COPY(VmaBlockMetadata_Generic) +public: + VmaBlockMetadata_Generic(VmaAllocator hAllocator); + virtual ~VmaBlockMetadata_Generic(); + virtual void Init(VkDeviceSize size); + + virtual bool Validate() const; + virtual size_t GetAllocationCount() const { return m_Suballocations.size() - m_FreeCount; } + virtual VkDeviceSize GetSumFreeSize() const { return m_SumFreeSize; } + virtual VkDeviceSize GetUnusedRangeSizeMax() const; + virtual bool IsEmpty() const; + + virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const; + virtual void AddPoolStats(VmaPoolStats& inoutStats) const; + +#if VMA_STATS_STRING_ENABLED + virtual void PrintDetailedMap(class VmaJsonWriter& json) const; +#endif + + virtual bool CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); + + virtual bool MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest); + + virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + + virtual VkResult CheckCorruption(const void* pBlockData); + + virtual void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation); + + virtual void Free(const VmaAllocation allocation); + virtual void FreeAtOffset(VkDeviceSize offset); + + //////////////////////////////////////////////////////////////////////////////// + // For defragmentation + + bool IsBufferImageGranularityConflictPossible( + VkDeviceSize bufferImageGranularity, + VmaSuballocationType& inOutPrevSuballocType) const; + +private: + friend class VmaDefragmentationAlgorithm_Generic; + friend class VmaDefragmentationAlgorithm_Fast; + + uint32_t m_FreeCount; + VkDeviceSize m_SumFreeSize; + VmaSuballocationList m_Suballocations; + // Suballocations that are free and have size greater than certain threshold. + // Sorted by size, ascending. + VmaVector< VmaSuballocationList::iterator, VmaStlAllocator< VmaSuballocationList::iterator > > m_FreeSuballocationsBySize; + + bool ValidateFreeSuballocationList() const; + + // Checks if requested suballocation with given parameters can be placed in given pFreeSuballocItem. + // If yes, fills pOffset and returns true. If no, returns false. + bool CheckAllocation( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + VmaSuballocationList::const_iterator suballocItem, + bool canMakeOtherLost, + VkDeviceSize* pOffset, + size_t* itemsToMakeLostCount, + VkDeviceSize* pSumFreeSize, + VkDeviceSize* pSumItemSize) const; + // Given free suballocation, it merges it with following one, which must also be free. + void MergeFreeWithNext(VmaSuballocationList::iterator item); + // Releases given suballocation, making it free. + // Merges it with adjacent free suballocations if applicable. + // Returns iterator to new free suballocation at this place. + VmaSuballocationList::iterator FreeSuballocation(VmaSuballocationList::iterator suballocItem); + // Given free suballocation, it inserts it into sorted list of + // m_FreeSuballocationsBySize if it's suitable. + void RegisterFreeSuballocation(VmaSuballocationList::iterator item); + // Given free suballocation, it removes it from sorted list of + // m_FreeSuballocationsBySize if it's suitable. + void UnregisterFreeSuballocation(VmaSuballocationList::iterator item); +}; + +/* +Allocations and their references in internal data structure look like this: + +if(m_2ndVectorMode == SECOND_VECTOR_EMPTY): + + 0 +-------+ + | | + | | + | | + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount] + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount + 1] + +-------+ + | ... | + +-------+ + | Alloc | 1st[1st.size() - 1] + +-------+ + | | + | | + | | +GetSize() +-------+ + +if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER): + + 0 +-------+ + | Alloc | 2nd[0] + +-------+ + | Alloc | 2nd[1] + +-------+ + | ... | + +-------+ + | Alloc | 2nd[2nd.size() - 1] + +-------+ + | | + | | + | | + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount] + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount + 1] + +-------+ + | ... | + +-------+ + | Alloc | 1st[1st.size() - 1] + +-------+ + | | +GetSize() +-------+ + +if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK): + + 0 +-------+ + | | + | | + | | + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount] + +-------+ + | Alloc | 1st[m_1stNullItemsBeginCount + 1] + +-------+ + | ... | + +-------+ + | Alloc | 1st[1st.size() - 1] + +-------+ + | | + | | + | | + +-------+ + | Alloc | 2nd[2nd.size() - 1] + +-------+ + | ... | + +-------+ + | Alloc | 2nd[1] + +-------+ + | Alloc | 2nd[0] +GetSize() +-------+ + +*/ +class VmaBlockMetadata_Linear : public VmaBlockMetadata +{ + VMA_CLASS_NO_COPY(VmaBlockMetadata_Linear) +public: + VmaBlockMetadata_Linear(VmaAllocator hAllocator); + virtual ~VmaBlockMetadata_Linear(); + virtual void Init(VkDeviceSize size); + + virtual bool Validate() const; + virtual size_t GetAllocationCount() const; + virtual VkDeviceSize GetSumFreeSize() const { return m_SumFreeSize; } + virtual VkDeviceSize GetUnusedRangeSizeMax() const; + virtual bool IsEmpty() const { return GetAllocationCount() == 0; } + + virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const; + virtual void AddPoolStats(VmaPoolStats& inoutStats) const; + +#if VMA_STATS_STRING_ENABLED + virtual void PrintDetailedMap(class VmaJsonWriter& json) const; +#endif + + virtual bool CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); + + virtual bool MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest); + + virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + + virtual VkResult CheckCorruption(const void* pBlockData); + + virtual void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation); + + virtual void Free(const VmaAllocation allocation); + virtual void FreeAtOffset(VkDeviceSize offset); + +private: + /* + There are two suballocation vectors, used in ping-pong way. + The one with index m_1stVectorIndex is called 1st. + The one with index (m_1stVectorIndex ^ 1) is called 2nd. + 2nd can be non-empty only when 1st is not empty. + When 2nd is not empty, m_2ndVectorMode indicates its mode of operation. + */ + typedef VmaVector< VmaSuballocation, VmaStlAllocator > SuballocationVectorType; + + enum SECOND_VECTOR_MODE + { + SECOND_VECTOR_EMPTY, + /* + Suballocations in 2nd vector are created later than the ones in 1st, but they + all have smaller offset. + */ + SECOND_VECTOR_RING_BUFFER, + /* + Suballocations in 2nd vector are upper side of double stack. + They all have offsets higher than those in 1st vector. + Top of this stack means smaller offsets, but higher indices in this vector. + */ + SECOND_VECTOR_DOUBLE_STACK, + }; + + VkDeviceSize m_SumFreeSize; + SuballocationVectorType m_Suballocations0, m_Suballocations1; + uint32_t m_1stVectorIndex; + SECOND_VECTOR_MODE m_2ndVectorMode; + + SuballocationVectorType& AccessSuballocations1st() { return m_1stVectorIndex ? m_Suballocations1 : m_Suballocations0; } + SuballocationVectorType& AccessSuballocations2nd() { return m_1stVectorIndex ? m_Suballocations0 : m_Suballocations1; } + const SuballocationVectorType& AccessSuballocations1st() const { return m_1stVectorIndex ? m_Suballocations1 : m_Suballocations0; } + const SuballocationVectorType& AccessSuballocations2nd() const { return m_1stVectorIndex ? m_Suballocations0 : m_Suballocations1; } + + // Number of items in 1st vector with hAllocation = null at the beginning. + size_t m_1stNullItemsBeginCount; + // Number of other items in 1st vector with hAllocation = null somewhere in the middle. + size_t m_1stNullItemsMiddleCount; + // Number of items in 2nd vector with hAllocation = null. + size_t m_2ndNullItemsCount; + + bool ShouldCompact1st() const; + void CleanupAfterFree(); + + bool CreateAllocationRequest_LowerAddress( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); + bool CreateAllocationRequest_UpperAddress( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); +}; + +/* +- GetSize() is the original size of allocated memory block. +- m_UsableSize is this size aligned down to a power of two. + All allocations and calculations happen relative to m_UsableSize. +- GetUnusableSize() is the difference between them. + It is repoted as separate, unused range, not available for allocations. + +Node at level 0 has size = m_UsableSize. +Each next level contains nodes with size 2 times smaller than current level. +m_LevelCount is the maximum number of levels to use in the current object. +*/ +class VmaBlockMetadata_Buddy : public VmaBlockMetadata +{ + VMA_CLASS_NO_COPY(VmaBlockMetadata_Buddy) +public: + VmaBlockMetadata_Buddy(VmaAllocator hAllocator); + virtual ~VmaBlockMetadata_Buddy(); + virtual void Init(VkDeviceSize size); + + virtual bool Validate() const; + virtual size_t GetAllocationCount() const { return m_AllocationCount; } + virtual VkDeviceSize GetSumFreeSize() const { return m_SumFreeSize + GetUnusableSize(); } + virtual VkDeviceSize GetUnusedRangeSizeMax() const; + virtual bool IsEmpty() const { return m_Root->type == Node::TYPE_FREE; } + + virtual void CalcAllocationStatInfo(VmaStatInfo& outInfo) const; + virtual void AddPoolStats(VmaPoolStats& inoutStats) const; + +#if VMA_STATS_STRING_ENABLED + virtual void PrintDetailedMap(class VmaJsonWriter& json) const; +#endif + + virtual bool CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest); + + virtual bool MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest); + + virtual uint32_t MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount); + + virtual VkResult CheckCorruption(const void* pBlockData) { return VK_ERROR_FEATURE_NOT_PRESENT; } + + virtual void Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation); + + virtual void Free(const VmaAllocation allocation) { FreeAtOffset(allocation, allocation->GetOffset()); } + virtual void FreeAtOffset(VkDeviceSize offset) { FreeAtOffset(VMA_NULL, offset); } + +private: + static const VkDeviceSize MIN_NODE_SIZE = 32; + static const size_t MAX_LEVELS = 30; + + struct ValidationContext + { + size_t calculatedAllocationCount; + size_t calculatedFreeCount; + VkDeviceSize calculatedSumFreeSize; + + ValidationContext() : + calculatedAllocationCount(0), + calculatedFreeCount(0), + calculatedSumFreeSize(0) { } + }; + + struct Node + { + VkDeviceSize offset; + enum TYPE + { + TYPE_FREE, + TYPE_ALLOCATION, + TYPE_SPLIT, + TYPE_COUNT + } type; + Node* parent; + Node* buddy; + + union + { + struct + { + Node* prev; + Node* next; + } free; + struct + { + VmaAllocation alloc; + } allocation; + struct + { + Node* leftChild; + } split; + }; + }; + + // Size of the memory block aligned down to a power of two. + VkDeviceSize m_UsableSize; + uint32_t m_LevelCount; + + Node* m_Root; + struct { + Node* front; + Node* back; + } m_FreeList[MAX_LEVELS]; + // Number of nodes in the tree with type == TYPE_ALLOCATION. + size_t m_AllocationCount; + // Number of nodes in the tree with type == TYPE_FREE. + size_t m_FreeCount; + // This includes space wasted due to internal fragmentation. Doesn't include unusable size. + VkDeviceSize m_SumFreeSize; + + VkDeviceSize GetUnusableSize() const { return GetSize() - m_UsableSize; } + void DeleteNode(Node* node); + bool ValidateNode(ValidationContext& ctx, const Node* parent, const Node* curr, uint32_t level, VkDeviceSize levelNodeSize) const; + uint32_t AllocSizeToLevel(VkDeviceSize allocSize) const; + inline VkDeviceSize LevelToNodeSize(uint32_t level) const { return m_UsableSize >> level; } + // Alloc passed just for validation. Can be null. + void FreeAtOffset(VmaAllocation alloc, VkDeviceSize offset); + void CalcAllocationStatInfoNode(VmaStatInfo& outInfo, const Node* node, VkDeviceSize levelNodeSize) const; + // Adds node to the front of FreeList at given level. + // node->type must be FREE. + // node->free.prev, next can be undefined. + void AddToFreeListFront(uint32_t level, Node* node); + // Removes node from FreeList at given level. + // node->type must be FREE. + // node->free.prev, next stay untouched. + void RemoveFromFreeList(uint32_t level, Node* node); + +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMapNode(class VmaJsonWriter& json, const Node* node, VkDeviceSize levelNodeSize) const; +#endif +}; + +/* +Represents a single block of device memory (`VkDeviceMemory`) with all the +data about its regions (aka suballocations, #VmaAllocation), assigned and free. + +Thread-safety: This class must be externally synchronized. +*/ +class VmaDeviceMemoryBlock +{ + VMA_CLASS_NO_COPY(VmaDeviceMemoryBlock) +public: + VmaBlockMetadata* m_pMetadata; + + VmaDeviceMemoryBlock(VmaAllocator hAllocator); + + ~VmaDeviceMemoryBlock() + { + VMA_ASSERT(m_MapCount == 0 && "VkDeviceMemory block is being destroyed while it is still mapped."); + VMA_ASSERT(m_hMemory == VK_NULL_HANDLE); + } + + // Always call after construction. + void Init( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t newMemoryTypeIndex, + VkDeviceMemory newMemory, + VkDeviceSize newSize, + uint32_t id, + uint32_t algorithm); + // Always call before destruction. + void Destroy(VmaAllocator allocator); + + VmaPool GetParentPool() const { return m_hParentPool; } + VkDeviceMemory GetDeviceMemory() const { return m_hMemory; } + uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } + uint32_t GetId() const { return m_Id; } + void* GetMappedData() const { return m_pMappedData; } + + // Validates all data structures inside this object. If not valid, returns false. + bool Validate() const; + + VkResult CheckCorruption(VmaAllocator hAllocator); + + // ppData can be null. + VkResult Map(VmaAllocator hAllocator, uint32_t count, void** ppData); + void Unmap(VmaAllocator hAllocator, uint32_t count); + + VkResult WriteMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize); + VkResult ValidateMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize); + + VkResult BindBufferMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext); + VkResult BindImageMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext); + +private: + VmaPool m_hParentPool; // VK_NULL_HANDLE if not belongs to custom pool. + uint32_t m_MemoryTypeIndex; + uint32_t m_Id; + VkDeviceMemory m_hMemory; + + /* + Protects access to m_hMemory so it's not used by multiple threads simultaneously, e.g. vkMapMemory, vkBindBufferMemory. + Also protects m_MapCount, m_pMappedData. + Allocations, deallocations, any change in m_pMetadata is protected by parent's VmaBlockVector::m_Mutex. + */ + VMA_MUTEX m_Mutex; + uint32_t m_MapCount; + void* m_pMappedData; +}; + +struct VmaPointerLess +{ + bool operator()(const void* lhs, const void* rhs) const + { + return lhs < rhs; + } +}; + +struct VmaDefragmentationMove +{ + size_t srcBlockIndex; + size_t dstBlockIndex; + VkDeviceSize srcOffset; + VkDeviceSize dstOffset; + VkDeviceSize size; + VmaAllocation hAllocation; + VmaDeviceMemoryBlock* pSrcBlock; + VmaDeviceMemoryBlock* pDstBlock; +}; + +class VmaDefragmentationAlgorithm; + +/* +Sequence of VmaDeviceMemoryBlock. Represents memory blocks allocated for a specific +Vulkan memory type. + +Synchronized internally with a mutex. +*/ +struct VmaBlockVector +{ + VMA_CLASS_NO_COPY(VmaBlockVector) +public: + VmaBlockVector( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t memoryTypeIndex, + VkDeviceSize preferredBlockSize, + size_t minBlockCount, + size_t maxBlockCount, + VkDeviceSize bufferImageGranularity, + uint32_t frameInUseCount, + bool explicitBlockSize, + uint32_t algorithm); + ~VmaBlockVector(); + + VkResult CreateMinBlocks(); + + VmaAllocator GetAllocator() const { return m_hAllocator; } + VmaPool GetParentPool() const { return m_hParentPool; } + bool IsCustomPool() const { return m_hParentPool != VMA_NULL; } + uint32_t GetMemoryTypeIndex() const { return m_MemoryTypeIndex; } + VkDeviceSize GetPreferredBlockSize() const { return m_PreferredBlockSize; } + VkDeviceSize GetBufferImageGranularity() const { return m_BufferImageGranularity; } + uint32_t GetFrameInUseCount() const { return m_FrameInUseCount; } + uint32_t GetAlgorithm() const { return m_Algorithm; } + + void GetPoolStats(VmaPoolStats* pStats); + + bool IsEmpty(); + bool IsCorruptionDetectionEnabled() const; + + VkResult Allocate( + uint32_t currentFrameIndex, + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations); + + void Free(const VmaAllocation hAllocation); + + // Adds statistics of this BlockVector to pStats. + void AddStats(VmaStats* pStats); + +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json); +#endif + + void MakePoolAllocationsLost( + uint32_t currentFrameIndex, + size_t* pLostAllocationCount); + VkResult CheckCorruption(); + + // Saves results in pCtx->res. + void Defragment( + class VmaBlockVectorDefragmentationContext* pCtx, + VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags, + VkDeviceSize& maxCpuBytesToMove, uint32_t& maxCpuAllocationsToMove, + VkDeviceSize& maxGpuBytesToMove, uint32_t& maxGpuAllocationsToMove, + VkCommandBuffer commandBuffer); + void DefragmentationEnd( + class VmaBlockVectorDefragmentationContext* pCtx, + uint32_t flags, + VmaDefragmentationStats* pStats); + + uint32_t ProcessDefragmentations( + class VmaBlockVectorDefragmentationContext *pCtx, + VmaDefragmentationPassMoveInfo* pMove, uint32_t maxMoves); + + void CommitDefragmentations( + class VmaBlockVectorDefragmentationContext *pCtx, + VmaDefragmentationStats* pStats); + + //////////////////////////////////////////////////////////////////////////////// + // To be used only while the m_Mutex is locked. Used during defragmentation. + + size_t GetBlockCount() const { return m_Blocks.size(); } + VmaDeviceMemoryBlock* GetBlock(size_t index) const { return m_Blocks[index]; } + size_t CalcAllocationCount() const; + bool IsBufferImageGranularityConflictPossible() const; + +private: + friend class VmaDefragmentationAlgorithm_Generic; + + const VmaAllocator m_hAllocator; + const VmaPool m_hParentPool; + const uint32_t m_MemoryTypeIndex; + const VkDeviceSize m_PreferredBlockSize; + const size_t m_MinBlockCount; + const size_t m_MaxBlockCount; + const VkDeviceSize m_BufferImageGranularity; + const uint32_t m_FrameInUseCount; + const bool m_ExplicitBlockSize; + const uint32_t m_Algorithm; + VMA_RW_MUTEX m_Mutex; + + /* There can be at most one allocation that is completely empty (except when minBlockCount > 0) - + a hysteresis to avoid pessimistic case of alternating creation and destruction of a VkDeviceMemory. */ + bool m_HasEmptyBlock; + // Incrementally sorted by sumFreeSize, ascending. + VmaVector< VmaDeviceMemoryBlock*, VmaStlAllocator > m_Blocks; + uint32_t m_NextBlockId; + + VkDeviceSize CalcMaxBlockSize() const; + + // Finds and removes given block from vector. + void Remove(VmaDeviceMemoryBlock* pBlock); + + // Performs single step in sorting m_Blocks. They may not be fully sorted + // after this call. + void IncrementallySortBlocks(); + + VkResult AllocatePage( + uint32_t currentFrameIndex, + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + VmaAllocation* pAllocation); + + // To be used only without CAN_MAKE_OTHER_LOST flag. + VkResult AllocateFromBlock( + VmaDeviceMemoryBlock* pBlock, + uint32_t currentFrameIndex, + VkDeviceSize size, + VkDeviceSize alignment, + VmaAllocationCreateFlags allocFlags, + void* pUserData, + VmaSuballocationType suballocType, + uint32_t strategy, + VmaAllocation* pAllocation); + + VkResult CreateBlock(VkDeviceSize blockSize, size_t* pNewBlockIndex); + + // Saves result to pCtx->res. + void ApplyDefragmentationMovesCpu( + class VmaBlockVectorDefragmentationContext* pDefragCtx, + const VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves); + // Saves result to pCtx->res. + void ApplyDefragmentationMovesGpu( + class VmaBlockVectorDefragmentationContext* pDefragCtx, + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkCommandBuffer commandBuffer); + + /* + Used during defragmentation. pDefragmentationStats is optional. It's in/out + - updated with new data. + */ + void FreeEmptyBlocks(VmaDefragmentationStats* pDefragmentationStats); + + void UpdateHasEmptyBlock(); +}; + +struct VmaPool_T +{ + VMA_CLASS_NO_COPY(VmaPool_T) +public: + VmaBlockVector m_BlockVector; + + VmaPool_T( + VmaAllocator hAllocator, + const VmaPoolCreateInfo& createInfo, + VkDeviceSize preferredBlockSize); + ~VmaPool_T(); + + uint32_t GetId() const { return m_Id; } + void SetId(uint32_t id) { VMA_ASSERT(m_Id == 0); m_Id = id; } + + const char* GetName() const { return m_Name; } + void SetName(const char* pName); + +#if VMA_STATS_STRING_ENABLED + //void PrintDetailedMap(class VmaStringBuilder& sb); +#endif + +private: + uint32_t m_Id; + char* m_Name; +}; + +/* +Performs defragmentation: + +- Updates `pBlockVector->m_pMetadata`. +- Updates allocations by calling ChangeBlockAllocation() or ChangeOffset(). +- Does not move actual data, only returns requested moves as `moves`. +*/ +class VmaDefragmentationAlgorithm +{ + VMA_CLASS_NO_COPY(VmaDefragmentationAlgorithm) +public: + VmaDefragmentationAlgorithm( + VmaAllocator hAllocator, + VmaBlockVector* pBlockVector, + uint32_t currentFrameIndex) : + m_hAllocator(hAllocator), + m_pBlockVector(pBlockVector), + m_CurrentFrameIndex(currentFrameIndex) + { + } + virtual ~VmaDefragmentationAlgorithm() + { + } + + virtual void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) = 0; + virtual void AddAll() = 0; + + virtual VkResult Defragment( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + VmaDefragmentationFlags flags) = 0; + + virtual VkDeviceSize GetBytesMoved() const = 0; + virtual uint32_t GetAllocationsMoved() const = 0; + +protected: + VmaAllocator const m_hAllocator; + VmaBlockVector* const m_pBlockVector; + const uint32_t m_CurrentFrameIndex; + + struct AllocationInfo + { + VmaAllocation m_hAllocation; + VkBool32* m_pChanged; + + AllocationInfo() : + m_hAllocation(VK_NULL_HANDLE), + m_pChanged(VMA_NULL) + { + } + AllocationInfo(VmaAllocation hAlloc, VkBool32* pChanged) : + m_hAllocation(hAlloc), + m_pChanged(pChanged) + { + } + }; +}; + +class VmaDefragmentationAlgorithm_Generic : public VmaDefragmentationAlgorithm +{ + VMA_CLASS_NO_COPY(VmaDefragmentationAlgorithm_Generic) +public: + VmaDefragmentationAlgorithm_Generic( + VmaAllocator hAllocator, + VmaBlockVector* pBlockVector, + uint32_t currentFrameIndex, + bool overlappingMoveSupported); + virtual ~VmaDefragmentationAlgorithm_Generic(); + + virtual void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged); + virtual void AddAll() { m_AllAllocations = true; } + + virtual VkResult Defragment( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + VmaDefragmentationFlags flags); + + virtual VkDeviceSize GetBytesMoved() const { return m_BytesMoved; } + virtual uint32_t GetAllocationsMoved() const { return m_AllocationsMoved; } + +private: + uint32_t m_AllocationCount; + bool m_AllAllocations; + + VkDeviceSize m_BytesMoved; + uint32_t m_AllocationsMoved; + + struct AllocationInfoSizeGreater + { + bool operator()(const AllocationInfo& lhs, const AllocationInfo& rhs) const + { + return lhs.m_hAllocation->GetSize() > rhs.m_hAllocation->GetSize(); + } + }; + + struct AllocationInfoOffsetGreater + { + bool operator()(const AllocationInfo& lhs, const AllocationInfo& rhs) const + { + return lhs.m_hAllocation->GetOffset() > rhs.m_hAllocation->GetOffset(); + } + }; + + struct BlockInfo + { + size_t m_OriginalBlockIndex; + VmaDeviceMemoryBlock* m_pBlock; + bool m_HasNonMovableAllocations; + VmaVector< AllocationInfo, VmaStlAllocator > m_Allocations; + + BlockInfo(const VkAllocationCallbacks* pAllocationCallbacks) : + m_OriginalBlockIndex(SIZE_MAX), + m_pBlock(VMA_NULL), + m_HasNonMovableAllocations(true), + m_Allocations(pAllocationCallbacks) + { + } + + void CalcHasNonMovableAllocations() + { + const size_t blockAllocCount = m_pBlock->m_pMetadata->GetAllocationCount(); + const size_t defragmentAllocCount = m_Allocations.size(); + m_HasNonMovableAllocations = blockAllocCount != defragmentAllocCount; + } + + void SortAllocationsBySizeDescending() + { + VMA_SORT(m_Allocations.begin(), m_Allocations.end(), AllocationInfoSizeGreater()); + } + + void SortAllocationsByOffsetDescending() + { + VMA_SORT(m_Allocations.begin(), m_Allocations.end(), AllocationInfoOffsetGreater()); + } + }; + + struct BlockPointerLess + { + bool operator()(const BlockInfo* pLhsBlockInfo, const VmaDeviceMemoryBlock* pRhsBlock) const + { + return pLhsBlockInfo->m_pBlock < pRhsBlock; + } + bool operator()(const BlockInfo* pLhsBlockInfo, const BlockInfo* pRhsBlockInfo) const + { + return pLhsBlockInfo->m_pBlock < pRhsBlockInfo->m_pBlock; + } + }; + + // 1. Blocks with some non-movable allocations go first. + // 2. Blocks with smaller sumFreeSize go first. + struct BlockInfoCompareMoveDestination + { + bool operator()(const BlockInfo* pLhsBlockInfo, const BlockInfo* pRhsBlockInfo) const + { + if(pLhsBlockInfo->m_HasNonMovableAllocations && !pRhsBlockInfo->m_HasNonMovableAllocations) + { + return true; + } + if(!pLhsBlockInfo->m_HasNonMovableAllocations && pRhsBlockInfo->m_HasNonMovableAllocations) + { + return false; + } + if(pLhsBlockInfo->m_pBlock->m_pMetadata->GetSumFreeSize() < pRhsBlockInfo->m_pBlock->m_pMetadata->GetSumFreeSize()) + { + return true; + } + return false; + } + }; + + typedef VmaVector< BlockInfo*, VmaStlAllocator > BlockInfoVector; + BlockInfoVector m_Blocks; + + VkResult DefragmentRound( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + bool freeOldAllocations); + + size_t CalcBlocksWithNonMovableCount() const; + + static bool MoveMakesSense( + size_t dstBlockIndex, VkDeviceSize dstOffset, + size_t srcBlockIndex, VkDeviceSize srcOffset); +}; + +class VmaDefragmentationAlgorithm_Fast : public VmaDefragmentationAlgorithm +{ + VMA_CLASS_NO_COPY(VmaDefragmentationAlgorithm_Fast) +public: + VmaDefragmentationAlgorithm_Fast( + VmaAllocator hAllocator, + VmaBlockVector* pBlockVector, + uint32_t currentFrameIndex, + bool overlappingMoveSupported); + virtual ~VmaDefragmentationAlgorithm_Fast(); + + virtual void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) { ++m_AllocationCount; } + virtual void AddAll() { m_AllAllocations = true; } + + virtual VkResult Defragment( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + VmaDefragmentationFlags flags); + + virtual VkDeviceSize GetBytesMoved() const { return m_BytesMoved; } + virtual uint32_t GetAllocationsMoved() const { return m_AllocationsMoved; } + +private: + struct BlockInfo + { + size_t origBlockIndex; + }; + + class FreeSpaceDatabase + { + public: + FreeSpaceDatabase() + { + FreeSpace s = {}; + s.blockInfoIndex = SIZE_MAX; + for(size_t i = 0; i < MAX_COUNT; ++i) + { + m_FreeSpaces[i] = s; + } + } + + void Register(size_t blockInfoIndex, VkDeviceSize offset, VkDeviceSize size) + { + if(size < VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + return; + } + + // Find first invalid or the smallest structure. + size_t bestIndex = SIZE_MAX; + for(size_t i = 0; i < MAX_COUNT; ++i) + { + // Empty structure. + if(m_FreeSpaces[i].blockInfoIndex == SIZE_MAX) + { + bestIndex = i; + break; + } + if(m_FreeSpaces[i].size < size && + (bestIndex == SIZE_MAX || m_FreeSpaces[bestIndex].size > m_FreeSpaces[i].size)) + { + bestIndex = i; + } + } + + if(bestIndex != SIZE_MAX) + { + m_FreeSpaces[bestIndex].blockInfoIndex = blockInfoIndex; + m_FreeSpaces[bestIndex].offset = offset; + m_FreeSpaces[bestIndex].size = size; + } + } + + bool Fetch(VkDeviceSize alignment, VkDeviceSize size, + size_t& outBlockInfoIndex, VkDeviceSize& outDstOffset) + { + size_t bestIndex = SIZE_MAX; + VkDeviceSize bestFreeSpaceAfter = 0; + for(size_t i = 0; i < MAX_COUNT; ++i) + { + // Structure is valid. + if(m_FreeSpaces[i].blockInfoIndex != SIZE_MAX) + { + const VkDeviceSize dstOffset = VmaAlignUp(m_FreeSpaces[i].offset, alignment); + // Allocation fits into this structure. + if(dstOffset + size <= m_FreeSpaces[i].offset + m_FreeSpaces[i].size) + { + const VkDeviceSize freeSpaceAfter = (m_FreeSpaces[i].offset + m_FreeSpaces[i].size) - + (dstOffset + size); + if(bestIndex == SIZE_MAX || freeSpaceAfter > bestFreeSpaceAfter) + { + bestIndex = i; + bestFreeSpaceAfter = freeSpaceAfter; + } + } + } + } + + if(bestIndex != SIZE_MAX) + { + outBlockInfoIndex = m_FreeSpaces[bestIndex].blockInfoIndex; + outDstOffset = VmaAlignUp(m_FreeSpaces[bestIndex].offset, alignment); + + if(bestFreeSpaceAfter >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + // Leave this structure for remaining empty space. + const VkDeviceSize alignmentPlusSize = (outDstOffset - m_FreeSpaces[bestIndex].offset) + size; + m_FreeSpaces[bestIndex].offset += alignmentPlusSize; + m_FreeSpaces[bestIndex].size -= alignmentPlusSize; + } + else + { + // This structure becomes invalid. + m_FreeSpaces[bestIndex].blockInfoIndex = SIZE_MAX; + } + + return true; + } + + return false; + } + + private: + static const size_t MAX_COUNT = 4; + + struct FreeSpace + { + size_t blockInfoIndex; // SIZE_MAX means this structure is invalid. + VkDeviceSize offset; + VkDeviceSize size; + } m_FreeSpaces[MAX_COUNT]; + }; + + const bool m_OverlappingMoveSupported; + + uint32_t m_AllocationCount; + bool m_AllAllocations; + + VkDeviceSize m_BytesMoved; + uint32_t m_AllocationsMoved; + + VmaVector< BlockInfo, VmaStlAllocator > m_BlockInfos; + + void PreprocessMetadata(); + void PostprocessMetadata(); + void InsertSuballoc(VmaBlockMetadata_Generic* pMetadata, const VmaSuballocation& suballoc); +}; + +struct VmaBlockDefragmentationContext +{ + enum BLOCK_FLAG + { + BLOCK_FLAG_USED = 0x00000001, + }; + uint32_t flags; + VkBuffer hBuffer; +}; + +class VmaBlockVectorDefragmentationContext +{ + VMA_CLASS_NO_COPY(VmaBlockVectorDefragmentationContext) +public: + VkResult res; + bool mutexLocked; + VmaVector< VmaBlockDefragmentationContext, VmaStlAllocator > blockContexts; + VmaVector< VmaDefragmentationMove, VmaStlAllocator > defragmentationMoves; + uint32_t defragmentationMovesProcessed; + uint32_t defragmentationMovesCommitted; + bool hasDefragmentationPlan; + + VmaBlockVectorDefragmentationContext( + VmaAllocator hAllocator, + VmaPool hCustomPool, // Optional. + VmaBlockVector* pBlockVector, + uint32_t currFrameIndex); + ~VmaBlockVectorDefragmentationContext(); + + VmaPool GetCustomPool() const { return m_hCustomPool; } + VmaBlockVector* GetBlockVector() const { return m_pBlockVector; } + VmaDefragmentationAlgorithm* GetAlgorithm() const { return m_pAlgorithm; } + + void AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged); + void AddAll() { m_AllAllocations = true; } + + void Begin(bool overlappingMoveSupported, VmaDefragmentationFlags flags); + +private: + const VmaAllocator m_hAllocator; + // Null if not from custom pool. + const VmaPool m_hCustomPool; + // Redundant, for convenience not to fetch from m_hCustomPool->m_BlockVector or m_hAllocator->m_pBlockVectors. + VmaBlockVector* const m_pBlockVector; + const uint32_t m_CurrFrameIndex; + // Owner of this object. + VmaDefragmentationAlgorithm* m_pAlgorithm; + + struct AllocInfo + { + VmaAllocation hAlloc; + VkBool32* pChanged; + }; + // Used between constructor and Begin. + VmaVector< AllocInfo, VmaStlAllocator > m_Allocations; + bool m_AllAllocations; +}; + +struct VmaDefragmentationContext_T +{ +private: + VMA_CLASS_NO_COPY(VmaDefragmentationContext_T) +public: + VmaDefragmentationContext_T( + VmaAllocator hAllocator, + uint32_t currFrameIndex, + uint32_t flags, + VmaDefragmentationStats* pStats); + ~VmaDefragmentationContext_T(); + + void AddPools(uint32_t poolCount, const VmaPool* pPools); + void AddAllocations( + uint32_t allocationCount, + const VmaAllocation* pAllocations, + VkBool32* pAllocationsChanged); + + /* + Returns: + - `VK_SUCCESS` if succeeded and object can be destroyed immediately. + - `VK_NOT_READY` if succeeded but the object must remain alive until vmaDefragmentationEnd(). + - Negative value if error occured and object can be destroyed immediately. + */ + VkResult Defragment( + VkDeviceSize maxCpuBytesToMove, uint32_t maxCpuAllocationsToMove, + VkDeviceSize maxGpuBytesToMove, uint32_t maxGpuAllocationsToMove, + VkCommandBuffer commandBuffer, VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags); + + VkResult DefragmentPassBegin(VmaDefragmentationPassInfo* pInfo); + VkResult DefragmentPassEnd(); + +private: + const VmaAllocator m_hAllocator; + const uint32_t m_CurrFrameIndex; + const uint32_t m_Flags; + VmaDefragmentationStats* const m_pStats; + + VkDeviceSize m_MaxCpuBytesToMove; + uint32_t m_MaxCpuAllocationsToMove; + VkDeviceSize m_MaxGpuBytesToMove; + uint32_t m_MaxGpuAllocationsToMove; + + // Owner of these objects. + VmaBlockVectorDefragmentationContext* m_DefaultPoolContexts[VK_MAX_MEMORY_TYPES]; + // Owner of these objects. + VmaVector< VmaBlockVectorDefragmentationContext*, VmaStlAllocator > m_CustomPoolContexts; +}; + +#if VMA_RECORDING_ENABLED + +class VmaRecorder +{ +public: + VmaRecorder(); + VkResult Init(const VmaRecordSettings& settings, bool useMutex); + void WriteConfiguration( + const VkPhysicalDeviceProperties& devProps, + const VkPhysicalDeviceMemoryProperties& memProps, + uint32_t vulkanApiVersion, + bool dedicatedAllocationExtensionEnabled, + bool bindMemory2ExtensionEnabled, + bool memoryBudgetExtensionEnabled, + bool deviceCoherentMemoryExtensionEnabled); + ~VmaRecorder(); + + void RecordCreateAllocator(uint32_t frameIndex); + void RecordDestroyAllocator(uint32_t frameIndex); + void RecordCreatePool(uint32_t frameIndex, + const VmaPoolCreateInfo& createInfo, + VmaPool pool); + void RecordDestroyPool(uint32_t frameIndex, VmaPool pool); + void RecordAllocateMemory(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + const VmaAllocationCreateInfo& createInfo, + VmaAllocation allocation); + void RecordAllocateMemoryPages(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + const VmaAllocationCreateInfo& createInfo, + uint64_t allocationCount, + const VmaAllocation* pAllocations); + void RecordAllocateMemoryForBuffer(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + const VmaAllocationCreateInfo& createInfo, + VmaAllocation allocation); + void RecordAllocateMemoryForImage(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + const VmaAllocationCreateInfo& createInfo, + VmaAllocation allocation); + void RecordFreeMemory(uint32_t frameIndex, + VmaAllocation allocation); + void RecordFreeMemoryPages(uint32_t frameIndex, + uint64_t allocationCount, + const VmaAllocation* pAllocations); + void RecordSetAllocationUserData(uint32_t frameIndex, + VmaAllocation allocation, + const void* pUserData); + void RecordCreateLostAllocation(uint32_t frameIndex, + VmaAllocation allocation); + void RecordMapMemory(uint32_t frameIndex, + VmaAllocation allocation); + void RecordUnmapMemory(uint32_t frameIndex, + VmaAllocation allocation); + void RecordFlushAllocation(uint32_t frameIndex, + VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size); + void RecordInvalidateAllocation(uint32_t frameIndex, + VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size); + void RecordCreateBuffer(uint32_t frameIndex, + const VkBufferCreateInfo& bufCreateInfo, + const VmaAllocationCreateInfo& allocCreateInfo, + VmaAllocation allocation); + void RecordCreateImage(uint32_t frameIndex, + const VkImageCreateInfo& imageCreateInfo, + const VmaAllocationCreateInfo& allocCreateInfo, + VmaAllocation allocation); + void RecordDestroyBuffer(uint32_t frameIndex, + VmaAllocation allocation); + void RecordDestroyImage(uint32_t frameIndex, + VmaAllocation allocation); + void RecordTouchAllocation(uint32_t frameIndex, + VmaAllocation allocation); + void RecordGetAllocationInfo(uint32_t frameIndex, + VmaAllocation allocation); + void RecordMakePoolAllocationsLost(uint32_t frameIndex, + VmaPool pool); + void RecordDefragmentationBegin(uint32_t frameIndex, + const VmaDefragmentationInfo2& info, + VmaDefragmentationContext ctx); + void RecordDefragmentationEnd(uint32_t frameIndex, + VmaDefragmentationContext ctx); + void RecordSetPoolName(uint32_t frameIndex, + VmaPool pool, + const char* name); + +private: + struct CallParams + { + uint32_t threadId; + double time; + }; + + class UserDataString + { + public: + UserDataString(VmaAllocationCreateFlags allocFlags, const void* pUserData); + const char* GetString() const { return m_Str; } + + private: + char m_PtrStr[17]; + const char* m_Str; + }; + + bool m_UseMutex; + VmaRecordFlags m_Flags; + FILE* m_File; + VMA_MUTEX m_FileMutex; + std::chrono::time_point m_RecordingStartTime; + + void GetBasicParams(CallParams& outParams); + + // T must be a pointer type, e.g. VmaAllocation, VmaPool. + template + void PrintPointerList(uint64_t count, const T* pItems) + { + if(count) + { + fprintf(m_File, "%p", pItems[0]); + for(uint64_t i = 1; i < count; ++i) + { + fprintf(m_File, " %p", pItems[i]); + } + } + } + + void PrintPointerList(uint64_t count, const VmaAllocation* pItems); + void Flush(); +}; + +#endif // #if VMA_RECORDING_ENABLED + +/* +Thread-safe wrapper over VmaPoolAllocator free list, for allocation of VmaAllocation_T objects. +*/ +class VmaAllocationObjectAllocator +{ + VMA_CLASS_NO_COPY(VmaAllocationObjectAllocator) +public: + VmaAllocationObjectAllocator(const VkAllocationCallbacks* pAllocationCallbacks); + + template VmaAllocation Allocate(Types... args); + void Free(VmaAllocation hAlloc); + +private: + VMA_MUTEX m_Mutex; + VmaPoolAllocator m_Allocator; +}; + +struct VmaCurrentBudgetData +{ + VMA_ATOMIC_UINT64 m_BlockBytes[VK_MAX_MEMORY_HEAPS]; + VMA_ATOMIC_UINT64 m_AllocationBytes[VK_MAX_MEMORY_HEAPS]; + +#if VMA_MEMORY_BUDGET + VMA_ATOMIC_UINT32 m_OperationsSinceBudgetFetch; + VMA_RW_MUTEX m_BudgetMutex; + uint64_t m_VulkanUsage[VK_MAX_MEMORY_HEAPS]; + uint64_t m_VulkanBudget[VK_MAX_MEMORY_HEAPS]; + uint64_t m_BlockBytesAtBudgetFetch[VK_MAX_MEMORY_HEAPS]; +#endif // #if VMA_MEMORY_BUDGET + + VmaCurrentBudgetData() + { + for(uint32_t heapIndex = 0; heapIndex < VK_MAX_MEMORY_HEAPS; ++heapIndex) + { + m_BlockBytes[heapIndex] = 0; + m_AllocationBytes[heapIndex] = 0; +#if VMA_MEMORY_BUDGET + m_VulkanUsage[heapIndex] = 0; + m_VulkanBudget[heapIndex] = 0; + m_BlockBytesAtBudgetFetch[heapIndex] = 0; +#endif + } + +#if VMA_MEMORY_BUDGET + m_OperationsSinceBudgetFetch = 0; +#endif + } + + void AddAllocation(uint32_t heapIndex, VkDeviceSize allocationSize) + { + m_AllocationBytes[heapIndex] += allocationSize; +#if VMA_MEMORY_BUDGET + ++m_OperationsSinceBudgetFetch; +#endif + } + + void RemoveAllocation(uint32_t heapIndex, VkDeviceSize allocationSize) + { + VMA_ASSERT(m_AllocationBytes[heapIndex] >= allocationSize); // DELME + m_AllocationBytes[heapIndex] -= allocationSize; +#if VMA_MEMORY_BUDGET + ++m_OperationsSinceBudgetFetch; +#endif + } +}; + +// Main allocator object. +struct VmaAllocator_T +{ + VMA_CLASS_NO_COPY(VmaAllocator_T) +public: + bool m_UseMutex; + uint32_t m_VulkanApiVersion; + bool m_UseKhrDedicatedAllocation; // Can be set only if m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0). + bool m_UseKhrBindMemory2; // Can be set only if m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0). + bool m_UseExtMemoryBudget; + bool m_UseAmdDeviceCoherentMemory; + bool m_UseKhrBufferDeviceAddress; + VkDevice m_hDevice; + VkInstance m_hInstance; + bool m_AllocationCallbacksSpecified; + VkAllocationCallbacks m_AllocationCallbacks; + VmaDeviceMemoryCallbacks m_DeviceMemoryCallbacks; + VmaAllocationObjectAllocator m_AllocationObjectAllocator; + + // Each bit (1 << i) is set if HeapSizeLimit is enabled for that heap, so cannot allocate more than the heap size. + uint32_t m_HeapSizeLimitMask; + + VkPhysicalDeviceProperties m_PhysicalDeviceProperties; + VkPhysicalDeviceMemoryProperties m_MemProps; + + // Default pools. + VmaBlockVector* m_pBlockVectors[VK_MAX_MEMORY_TYPES]; + + // Each vector is sorted by memory (handle value). + typedef VmaVector< VmaAllocation, VmaStlAllocator > AllocationVectorType; + AllocationVectorType* m_pDedicatedAllocations[VK_MAX_MEMORY_TYPES]; + VMA_RW_MUTEX m_DedicatedAllocationsMutex[VK_MAX_MEMORY_TYPES]; + + VmaCurrentBudgetData m_Budget; + + VmaAllocator_T(const VmaAllocatorCreateInfo* pCreateInfo); + VkResult Init(const VmaAllocatorCreateInfo* pCreateInfo); + ~VmaAllocator_T(); + + const VkAllocationCallbacks* GetAllocationCallbacks() const + { + return m_AllocationCallbacksSpecified ? &m_AllocationCallbacks : 0; + } + const VmaVulkanFunctions& GetVulkanFunctions() const + { + return m_VulkanFunctions; + } + + VkPhysicalDevice GetPhysicalDevice() const { return m_PhysicalDevice; } + + VkDeviceSize GetBufferImageGranularity() const + { + return VMA_MAX( + static_cast(VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY), + m_PhysicalDeviceProperties.limits.bufferImageGranularity); + } + + uint32_t GetMemoryHeapCount() const { return m_MemProps.memoryHeapCount; } + uint32_t GetMemoryTypeCount() const { return m_MemProps.memoryTypeCount; } + + uint32_t MemoryTypeIndexToHeapIndex(uint32_t memTypeIndex) const + { + VMA_ASSERT(memTypeIndex < m_MemProps.memoryTypeCount); + return m_MemProps.memoryTypes[memTypeIndex].heapIndex; + } + // True when specific memory type is HOST_VISIBLE but not HOST_COHERENT. + bool IsMemoryTypeNonCoherent(uint32_t memTypeIndex) const + { + return (m_MemProps.memoryTypes[memTypeIndex].propertyFlags & (VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT)) == + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + } + // Minimum alignment for all allocations in specific memory type. + VkDeviceSize GetMemoryTypeMinAlignment(uint32_t memTypeIndex) const + { + return IsMemoryTypeNonCoherent(memTypeIndex) ? + VMA_MAX((VkDeviceSize)VMA_DEBUG_ALIGNMENT, m_PhysicalDeviceProperties.limits.nonCoherentAtomSize) : + (VkDeviceSize)VMA_DEBUG_ALIGNMENT; + } + + bool IsIntegratedGpu() const + { + return m_PhysicalDeviceProperties.deviceType == VK_PHYSICAL_DEVICE_TYPE_INTEGRATED_GPU; + } + + uint32_t GetGlobalMemoryTypeBits() const { return m_GlobalMemoryTypeBits; } + +#if VMA_RECORDING_ENABLED + VmaRecorder* GetRecorder() const { return m_pRecorder; } +#endif + + void GetBufferMemoryRequirements( + VkBuffer hBuffer, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const; + void GetImageMemoryRequirements( + VkImage hImage, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const; + + // Main allocation function. + VkResult AllocateMemory( + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + VkBuffer dedicatedBuffer, + VkBufferUsageFlags dedicatedBufferUsage, // UINT32_MAX when unknown. + VkImage dedicatedImage, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations); + + // Main deallocation function. + void FreeMemory( + size_t allocationCount, + const VmaAllocation* pAllocations); + + VkResult ResizeAllocation( + const VmaAllocation alloc, + VkDeviceSize newSize); + + void CalculateStats(VmaStats* pStats); + + void GetBudget( + VmaBudget* outBudget, uint32_t firstHeap, uint32_t heapCount); + +#if VMA_STATS_STRING_ENABLED + void PrintDetailedMap(class VmaJsonWriter& json); +#endif + + VkResult DefragmentationBegin( + const VmaDefragmentationInfo2& info, + VmaDefragmentationStats* pStats, + VmaDefragmentationContext* pContext); + VkResult DefragmentationEnd( + VmaDefragmentationContext context); + + VkResult DefragmentationPassBegin( + VmaDefragmentationPassInfo* pInfo, + VmaDefragmentationContext context); + VkResult DefragmentationPassEnd( + VmaDefragmentationContext context); + + void GetAllocationInfo(VmaAllocation hAllocation, VmaAllocationInfo* pAllocationInfo); + bool TouchAllocation(VmaAllocation hAllocation); + + VkResult CreatePool(const VmaPoolCreateInfo* pCreateInfo, VmaPool* pPool); + void DestroyPool(VmaPool pool); + void GetPoolStats(VmaPool pool, VmaPoolStats* pPoolStats); + + void SetCurrentFrameIndex(uint32_t frameIndex); + uint32_t GetCurrentFrameIndex() const { return m_CurrentFrameIndex.load(); } + + void MakePoolAllocationsLost( + VmaPool hPool, + size_t* pLostAllocationCount); + VkResult CheckPoolCorruption(VmaPool hPool); + VkResult CheckCorruption(uint32_t memoryTypeBits); + + void CreateLostAllocation(VmaAllocation* pAllocation); + + // Call to Vulkan function vkAllocateMemory with accompanying bookkeeping. + VkResult AllocateVulkanMemory(const VkMemoryAllocateInfo* pAllocateInfo, VkDeviceMemory* pMemory); + // Call to Vulkan function vkFreeMemory with accompanying bookkeeping. + void FreeVulkanMemory(uint32_t memoryType, VkDeviceSize size, VkDeviceMemory hMemory); + // Call to Vulkan function vkBindBufferMemory or vkBindBufferMemory2KHR. + VkResult BindVulkanBuffer( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkBuffer buffer, + const void* pNext); + // Call to Vulkan function vkBindImageMemory or vkBindImageMemory2KHR. + VkResult BindVulkanImage( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkImage image, + const void* pNext); + + VkResult Map(VmaAllocation hAllocation, void** ppData); + void Unmap(VmaAllocation hAllocation); + + VkResult BindBufferMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext); + VkResult BindImageMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext); + + VkResult FlushOrInvalidateAllocation( + VmaAllocation hAllocation, + VkDeviceSize offset, VkDeviceSize size, + VMA_CACHE_OPERATION op); + VkResult FlushOrInvalidateAllocations( + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, const VkDeviceSize* sizes, + VMA_CACHE_OPERATION op); + + void FillAllocation(const VmaAllocation hAllocation, uint8_t pattern); + + /* + Returns bit mask of memory types that can support defragmentation on GPU as + they support creation of required buffer for copy operations. + */ + uint32_t GetGpuDefragmentationMemoryTypeBits(); + +private: + VkDeviceSize m_PreferredLargeHeapBlockSize; + + VkPhysicalDevice m_PhysicalDevice; + VMA_ATOMIC_UINT32 m_CurrentFrameIndex; + VMA_ATOMIC_UINT32 m_GpuDefragmentationMemoryTypeBits; // UINT32_MAX means uninitialized. + + VMA_RW_MUTEX m_PoolsMutex; + // Protected by m_PoolsMutex. Sorted by pointer value. + VmaVector > m_Pools; + uint32_t m_NextPoolId; + + VmaVulkanFunctions m_VulkanFunctions; + + // Global bit mask AND-ed with any memoryTypeBits to disallow certain memory types. + uint32_t m_GlobalMemoryTypeBits; + +#if VMA_RECORDING_ENABLED + VmaRecorder* m_pRecorder; +#endif + + void ImportVulkanFunctions(const VmaVulkanFunctions* pVulkanFunctions); + +#if VMA_STATIC_VULKAN_FUNCTIONS == 1 + void ImportVulkanFunctions_Static(); +#endif + + void ImportVulkanFunctions_Custom(const VmaVulkanFunctions* pVulkanFunctions); + +#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + void ImportVulkanFunctions_Dynamic(); +#endif + + void ValidateVulkanFunctions(); + + VkDeviceSize CalcPreferredBlockSize(uint32_t memTypeIndex); + + VkResult AllocateMemoryOfType( + VkDeviceSize size, + VkDeviceSize alignment, + bool dedicatedAllocation, + VkBuffer dedicatedBuffer, + VkBufferUsageFlags dedicatedBufferUsage, + VkImage dedicatedImage, + const VmaAllocationCreateInfo& createInfo, + uint32_t memTypeIndex, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations); + + // Helper function only to be used inside AllocateDedicatedMemory. + VkResult AllocateDedicatedMemoryPage( + VkDeviceSize size, + VmaSuballocationType suballocType, + uint32_t memTypeIndex, + const VkMemoryAllocateInfo& allocInfo, + bool map, + bool isUserDataString, + void* pUserData, + VmaAllocation* pAllocation); + + // Allocates and registers new VkDeviceMemory specifically for dedicated allocations. + VkResult AllocateDedicatedMemory( + VkDeviceSize size, + VmaSuballocationType suballocType, + uint32_t memTypeIndex, + bool withinBudget, + bool map, + bool isUserDataString, + void* pUserData, + VkBuffer dedicatedBuffer, + VkBufferUsageFlags dedicatedBufferUsage, + VkImage dedicatedImage, + size_t allocationCount, + VmaAllocation* pAllocations); + + void FreeDedicatedMemory(const VmaAllocation allocation); + + /* + Calculates and returns bit mask of memory types that can support defragmentation + on GPU as they support creation of required buffer for copy operations. + */ + uint32_t CalculateGpuDefragmentationMemoryTypeBits() const; + + uint32_t CalculateGlobalMemoryTypeBits() const; + + bool GetFlushOrInvalidateRange( + VmaAllocation allocation, + VkDeviceSize offset, VkDeviceSize size, + VkMappedMemoryRange& outRange) const; + +#if VMA_MEMORY_BUDGET + void UpdateVulkanBudget(); +#endif // #if VMA_MEMORY_BUDGET +}; + +//////////////////////////////////////////////////////////////////////////////// +// Memory allocation #2 after VmaAllocator_T definition + +static void* VmaMalloc(VmaAllocator hAllocator, size_t size, size_t alignment) +{ + return VmaMalloc(&hAllocator->m_AllocationCallbacks, size, alignment); +} + +static void VmaFree(VmaAllocator hAllocator, void* ptr) +{ + VmaFree(&hAllocator->m_AllocationCallbacks, ptr); +} + +template +static T* VmaAllocate(VmaAllocator hAllocator) +{ + return (T*)VmaMalloc(hAllocator, sizeof(T), VMA_ALIGN_OF(T)); +} + +template +static T* VmaAllocateArray(VmaAllocator hAllocator, size_t count) +{ + return (T*)VmaMalloc(hAllocator, sizeof(T) * count, VMA_ALIGN_OF(T)); +} + +template +static void vma_delete(VmaAllocator hAllocator, T* ptr) +{ + if(ptr != VMA_NULL) + { + ptr->~T(); + VmaFree(hAllocator, ptr); + } +} + +template +static void vma_delete_array(VmaAllocator hAllocator, T* ptr, size_t count) +{ + if(ptr != VMA_NULL) + { + for(size_t i = count; i--; ) + ptr[i].~T(); + VmaFree(hAllocator, ptr); + } +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaStringBuilder + +#if VMA_STATS_STRING_ENABLED + +class VmaStringBuilder +{ +public: + VmaStringBuilder(VmaAllocator alloc) : m_Data(VmaStlAllocator(alloc->GetAllocationCallbacks())) { } + size_t GetLength() const { return m_Data.size(); } + const char* GetData() const { return m_Data.data(); } + + void Add(char ch) { m_Data.push_back(ch); } + void Add(const char* pStr); + void AddNewLine() { Add('\n'); } + void AddNumber(uint32_t num); + void AddNumber(uint64_t num); + void AddPointer(const void* ptr); + +private: + VmaVector< char, VmaStlAllocator > m_Data; +}; + +void VmaStringBuilder::Add(const char* pStr) +{ + const size_t strLen = strlen(pStr); + if(strLen > 0) + { + const size_t oldCount = m_Data.size(); + m_Data.resize(oldCount + strLen); + memcpy(m_Data.data() + oldCount, pStr, strLen); + } +} + +void VmaStringBuilder::AddNumber(uint32_t num) +{ + char buf[11]; + buf[10] = '\0'; + char *p = &buf[10]; + do + { + *--p = '0' + (num % 10); + num /= 10; + } + while(num); + Add(p); +} + +void VmaStringBuilder::AddNumber(uint64_t num) +{ + char buf[21]; + buf[20] = '\0'; + char *p = &buf[20]; + do + { + *--p = '0' + (num % 10); + num /= 10; + } + while(num); + Add(p); +} + +void VmaStringBuilder::AddPointer(const void* ptr) +{ + char buf[21]; + VmaPtrToStr(buf, sizeof(buf), ptr); + Add(buf); +} + +#endif // #if VMA_STATS_STRING_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +// VmaJsonWriter + +#if VMA_STATS_STRING_ENABLED + +class VmaJsonWriter +{ + VMA_CLASS_NO_COPY(VmaJsonWriter) +public: + VmaJsonWriter(const VkAllocationCallbacks* pAllocationCallbacks, VmaStringBuilder& sb); + ~VmaJsonWriter(); + + void BeginObject(bool singleLine = false); + void EndObject(); + + void BeginArray(bool singleLine = false); + void EndArray(); + + void WriteString(const char* pStr); + void BeginString(const char* pStr = VMA_NULL); + void ContinueString(const char* pStr); + void ContinueString(uint32_t n); + void ContinueString(uint64_t n); + void ContinueString_Pointer(const void* ptr); + void EndString(const char* pStr = VMA_NULL); + + void WriteNumber(uint32_t n); + void WriteNumber(uint64_t n); + void WriteBool(bool b); + void WriteNull(); + +private: + static const char* const INDENT; + + enum COLLECTION_TYPE + { + COLLECTION_TYPE_OBJECT, + COLLECTION_TYPE_ARRAY, + }; + struct StackItem + { + COLLECTION_TYPE type; + uint32_t valueCount; + bool singleLineMode; + }; + + VmaStringBuilder& m_SB; + VmaVector< StackItem, VmaStlAllocator > m_Stack; + bool m_InsideString; + + void BeginValue(bool isString); + void WriteIndent(bool oneLess = false); +}; + +const char* const VmaJsonWriter::INDENT = " "; + +VmaJsonWriter::VmaJsonWriter(const VkAllocationCallbacks* pAllocationCallbacks, VmaStringBuilder& sb) : + m_SB(sb), + m_Stack(VmaStlAllocator(pAllocationCallbacks)), + m_InsideString(false) +{ +} + +VmaJsonWriter::~VmaJsonWriter() +{ + VMA_ASSERT(!m_InsideString); + VMA_ASSERT(m_Stack.empty()); +} + +void VmaJsonWriter::BeginObject(bool singleLine) +{ + VMA_ASSERT(!m_InsideString); + + BeginValue(false); + m_SB.Add('{'); + + StackItem item; + item.type = COLLECTION_TYPE_OBJECT; + item.valueCount = 0; + item.singleLineMode = singleLine; + m_Stack.push_back(item); +} + +void VmaJsonWriter::EndObject() +{ + VMA_ASSERT(!m_InsideString); + + WriteIndent(true); + m_SB.Add('}'); + + VMA_ASSERT(!m_Stack.empty() && m_Stack.back().type == COLLECTION_TYPE_OBJECT); + m_Stack.pop_back(); +} + +void VmaJsonWriter::BeginArray(bool singleLine) +{ + VMA_ASSERT(!m_InsideString); + + BeginValue(false); + m_SB.Add('['); + + StackItem item; + item.type = COLLECTION_TYPE_ARRAY; + item.valueCount = 0; + item.singleLineMode = singleLine; + m_Stack.push_back(item); +} + +void VmaJsonWriter::EndArray() +{ + VMA_ASSERT(!m_InsideString); + + WriteIndent(true); + m_SB.Add(']'); + + VMA_ASSERT(!m_Stack.empty() && m_Stack.back().type == COLLECTION_TYPE_ARRAY); + m_Stack.pop_back(); +} + +void VmaJsonWriter::WriteString(const char* pStr) +{ + BeginString(pStr); + EndString(); +} + +void VmaJsonWriter::BeginString(const char* pStr) +{ + VMA_ASSERT(!m_InsideString); + + BeginValue(true); + m_SB.Add('"'); + m_InsideString = true; + if(pStr != VMA_NULL && pStr[0] != '\0') + { + ContinueString(pStr); + } +} + +void VmaJsonWriter::ContinueString(const char* pStr) +{ + VMA_ASSERT(m_InsideString); + + const size_t strLen = strlen(pStr); + for(size_t i = 0; i < strLen; ++i) + { + char ch = pStr[i]; + if(ch == '\\') + { + m_SB.Add("\\\\"); + } + else if(ch == '"') + { + m_SB.Add("\\\""); + } + else if(ch >= 32) + { + m_SB.Add(ch); + } + else switch(ch) + { + case '\b': + m_SB.Add("\\b"); + break; + case '\f': + m_SB.Add("\\f"); + break; + case '\n': + m_SB.Add("\\n"); + break; + case '\r': + m_SB.Add("\\r"); + break; + case '\t': + m_SB.Add("\\t"); + break; + default: + VMA_ASSERT(0 && "Character not currently supported."); + break; + } + } +} + +void VmaJsonWriter::ContinueString(uint32_t n) +{ + VMA_ASSERT(m_InsideString); + m_SB.AddNumber(n); +} + +void VmaJsonWriter::ContinueString(uint64_t n) +{ + VMA_ASSERT(m_InsideString); + m_SB.AddNumber(n); +} + +void VmaJsonWriter::ContinueString_Pointer(const void* ptr) +{ + VMA_ASSERT(m_InsideString); + m_SB.AddPointer(ptr); +} + +void VmaJsonWriter::EndString(const char* pStr) +{ + VMA_ASSERT(m_InsideString); + if(pStr != VMA_NULL && pStr[0] != '\0') + { + ContinueString(pStr); + } + m_SB.Add('"'); + m_InsideString = false; +} + +void VmaJsonWriter::WriteNumber(uint32_t n) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.AddNumber(n); +} + +void VmaJsonWriter::WriteNumber(uint64_t n) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.AddNumber(n); +} + +void VmaJsonWriter::WriteBool(bool b) +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.Add(b ? "true" : "false"); +} + +void VmaJsonWriter::WriteNull() +{ + VMA_ASSERT(!m_InsideString); + BeginValue(false); + m_SB.Add("null"); +} + +void VmaJsonWriter::BeginValue(bool isString) +{ + if(!m_Stack.empty()) + { + StackItem& currItem = m_Stack.back(); + if(currItem.type == COLLECTION_TYPE_OBJECT && + currItem.valueCount % 2 == 0) + { + VMA_ASSERT(isString); + } + + if(currItem.type == COLLECTION_TYPE_OBJECT && + currItem.valueCount % 2 != 0) + { + m_SB.Add(": "); + } + else if(currItem.valueCount > 0) + { + m_SB.Add(", "); + WriteIndent(); + } + else + { + WriteIndent(); + } + ++currItem.valueCount; + } +} + +void VmaJsonWriter::WriteIndent(bool oneLess) +{ + if(!m_Stack.empty() && !m_Stack.back().singleLineMode) + { + m_SB.AddNewLine(); + + size_t count = m_Stack.size(); + if(count > 0 && oneLess) + { + --count; + } + for(size_t i = 0; i < count; ++i) + { + m_SB.Add(INDENT); + } + } +} + +#endif // #if VMA_STATS_STRING_ENABLED + +//////////////////////////////////////////////////////////////////////////////// + +void VmaAllocation_T::SetUserData(VmaAllocator hAllocator, void* pUserData) +{ + if(IsUserDataString()) + { + VMA_ASSERT(pUserData == VMA_NULL || pUserData != m_pUserData); + + FreeUserDataString(hAllocator); + + if(pUserData != VMA_NULL) + { + m_pUserData = VmaCreateStringCopy(hAllocator->GetAllocationCallbacks(), (const char*)pUserData); + } + } + else + { + m_pUserData = pUserData; + } +} + +void VmaAllocation_T::ChangeBlockAllocation( + VmaAllocator hAllocator, + VmaDeviceMemoryBlock* block, + VkDeviceSize offset) +{ + VMA_ASSERT(block != VMA_NULL); + VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); + + // Move mapping reference counter from old block to new block. + if(block != m_BlockAllocation.m_Block) + { + uint32_t mapRefCount = m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP; + if(IsPersistentMap()) + ++mapRefCount; + m_BlockAllocation.m_Block->Unmap(hAllocator, mapRefCount); + block->Map(hAllocator, mapRefCount, VMA_NULL); + } + + m_BlockAllocation.m_Block = block; + m_BlockAllocation.m_Offset = offset; +} + +void VmaAllocation_T::ChangeOffset(VkDeviceSize newOffset) +{ + VMA_ASSERT(m_Type == ALLOCATION_TYPE_BLOCK); + m_BlockAllocation.m_Offset = newOffset; +} + +VkDeviceSize VmaAllocation_T::GetOffset() const +{ + switch(m_Type) + { + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_Offset; + case ALLOCATION_TYPE_DEDICATED: + return 0; + default: + VMA_ASSERT(0); + return 0; + } +} + +VkDeviceMemory VmaAllocation_T::GetMemory() const +{ + switch(m_Type) + { + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_Block->GetDeviceMemory(); + case ALLOCATION_TYPE_DEDICATED: + return m_DedicatedAllocation.m_hMemory; + default: + VMA_ASSERT(0); + return VK_NULL_HANDLE; + } +} + +void* VmaAllocation_T::GetMappedData() const +{ + switch(m_Type) + { + case ALLOCATION_TYPE_BLOCK: + if(m_MapCount != 0) + { + void* pBlockData = m_BlockAllocation.m_Block->GetMappedData(); + VMA_ASSERT(pBlockData != VMA_NULL); + return (char*)pBlockData + m_BlockAllocation.m_Offset; + } + else + { + return VMA_NULL; + } + break; + case ALLOCATION_TYPE_DEDICATED: + VMA_ASSERT((m_DedicatedAllocation.m_pMappedData != VMA_NULL) == (m_MapCount != 0)); + return m_DedicatedAllocation.m_pMappedData; + default: + VMA_ASSERT(0); + return VMA_NULL; + } +} + +bool VmaAllocation_T::CanBecomeLost() const +{ + switch(m_Type) + { + case ALLOCATION_TYPE_BLOCK: + return m_BlockAllocation.m_CanBecomeLost; + case ALLOCATION_TYPE_DEDICATED: + return false; + default: + VMA_ASSERT(0); + return false; + } +} + +bool VmaAllocation_T::MakeLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) +{ + VMA_ASSERT(CanBecomeLost()); + + /* + Warning: This is a carefully designed algorithm. + Do not modify unless you really know what you're doing :) + */ + uint32_t localLastUseFrameIndex = GetLastUseFrameIndex(); + for(;;) + { + if(localLastUseFrameIndex == VMA_FRAME_INDEX_LOST) + { + VMA_ASSERT(0); + return false; + } + else if(localLastUseFrameIndex + frameInUseCount >= currentFrameIndex) + { + return false; + } + else // Last use time earlier than current time. + { + if(CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, VMA_FRAME_INDEX_LOST)) + { + // Setting hAllocation.LastUseFrameIndex atomic to VMA_FRAME_INDEX_LOST is enough to mark it as LOST. + // Calling code just needs to unregister this allocation in owning VmaDeviceMemoryBlock. + return true; + } + } + } +} + +#if VMA_STATS_STRING_ENABLED + +// Correspond to values of enum VmaSuballocationType. +static const char* VMA_SUBALLOCATION_TYPE_NAMES[] = { + "FREE", + "UNKNOWN", + "BUFFER", + "IMAGE_UNKNOWN", + "IMAGE_LINEAR", + "IMAGE_OPTIMAL", +}; + +void VmaAllocation_T::PrintParameters(class VmaJsonWriter& json) const +{ + json.WriteString("Type"); + json.WriteString(VMA_SUBALLOCATION_TYPE_NAMES[m_SuballocationType]); + + json.WriteString("Size"); + json.WriteNumber(m_Size); + + if(m_pUserData != VMA_NULL) + { + json.WriteString("UserData"); + if(IsUserDataString()) + { + json.WriteString((const char*)m_pUserData); + } + else + { + json.BeginString(); + json.ContinueString_Pointer(m_pUserData); + json.EndString(); + } + } + + json.WriteString("CreationFrameIndex"); + json.WriteNumber(m_CreationFrameIndex); + + json.WriteString("LastUseFrameIndex"); + json.WriteNumber(GetLastUseFrameIndex()); + + if(m_BufferImageUsage != 0) + { + json.WriteString("Usage"); + json.WriteNumber(m_BufferImageUsage); + } +} + +#endif + +void VmaAllocation_T::FreeUserDataString(VmaAllocator hAllocator) +{ + VMA_ASSERT(IsUserDataString()); + VmaFreeString(hAllocator->GetAllocationCallbacks(), (char*)m_pUserData); + m_pUserData = VMA_NULL; +} + +void VmaAllocation_T::BlockAllocMap() +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_BLOCK); + + if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) < 0x7F) + { + ++m_MapCount; + } + else + { + VMA_ASSERT(0 && "Allocation mapped too many times simultaneously."); + } +} + +void VmaAllocation_T::BlockAllocUnmap() +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_BLOCK); + + if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) != 0) + { + --m_MapCount; + } + else + { + VMA_ASSERT(0 && "Unmapping allocation not previously mapped."); + } +} + +VkResult VmaAllocation_T::DedicatedAllocMap(VmaAllocator hAllocator, void** ppData) +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_DEDICATED); + + if(m_MapCount != 0) + { + if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) < 0x7F) + { + VMA_ASSERT(m_DedicatedAllocation.m_pMappedData != VMA_NULL); + *ppData = m_DedicatedAllocation.m_pMappedData; + ++m_MapCount; + return VK_SUCCESS; + } + else + { + VMA_ASSERT(0 && "Dedicated allocation mapped too many times simultaneously."); + return VK_ERROR_MEMORY_MAP_FAILED; + } + } + else + { + VkResult result = (*hAllocator->GetVulkanFunctions().vkMapMemory)( + hAllocator->m_hDevice, + m_DedicatedAllocation.m_hMemory, + 0, // offset + VK_WHOLE_SIZE, + 0, // flags + ppData); + if(result == VK_SUCCESS) + { + m_DedicatedAllocation.m_pMappedData = *ppData; + m_MapCount = 1; + } + return result; + } +} + +void VmaAllocation_T::DedicatedAllocUnmap(VmaAllocator hAllocator) +{ + VMA_ASSERT(GetType() == ALLOCATION_TYPE_DEDICATED); + + if((m_MapCount & ~MAP_COUNT_FLAG_PERSISTENT_MAP) != 0) + { + --m_MapCount; + if(m_MapCount == 0) + { + m_DedicatedAllocation.m_pMappedData = VMA_NULL; + (*hAllocator->GetVulkanFunctions().vkUnmapMemory)( + hAllocator->m_hDevice, + m_DedicatedAllocation.m_hMemory); + } + } + else + { + VMA_ASSERT(0 && "Unmapping dedicated allocation not previously mapped."); + } +} + +#if VMA_STATS_STRING_ENABLED + +static void VmaPrintStatInfo(VmaJsonWriter& json, const VmaStatInfo& stat) +{ + json.BeginObject(); + + json.WriteString("Blocks"); + json.WriteNumber(stat.blockCount); + + json.WriteString("Allocations"); + json.WriteNumber(stat.allocationCount); + + json.WriteString("UnusedRanges"); + json.WriteNumber(stat.unusedRangeCount); + + json.WriteString("UsedBytes"); + json.WriteNumber(stat.usedBytes); + + json.WriteString("UnusedBytes"); + json.WriteNumber(stat.unusedBytes); + + if(stat.allocationCount > 1) + { + json.WriteString("AllocationSize"); + json.BeginObject(true); + json.WriteString("Min"); + json.WriteNumber(stat.allocationSizeMin); + json.WriteString("Avg"); + json.WriteNumber(stat.allocationSizeAvg); + json.WriteString("Max"); + json.WriteNumber(stat.allocationSizeMax); + json.EndObject(); + } + + if(stat.unusedRangeCount > 1) + { + json.WriteString("UnusedRangeSize"); + json.BeginObject(true); + json.WriteString("Min"); + json.WriteNumber(stat.unusedRangeSizeMin); + json.WriteString("Avg"); + json.WriteNumber(stat.unusedRangeSizeAvg); + json.WriteString("Max"); + json.WriteNumber(stat.unusedRangeSizeMax); + json.EndObject(); + } + + json.EndObject(); +} + +#endif // #if VMA_STATS_STRING_ENABLED + +struct VmaSuballocationItemSizeLess +{ + bool operator()( + const VmaSuballocationList::iterator lhs, + const VmaSuballocationList::iterator rhs) const + { + return lhs->size < rhs->size; + } + bool operator()( + const VmaSuballocationList::iterator lhs, + VkDeviceSize rhsSize) const + { + return lhs->size < rhsSize; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// class VmaBlockMetadata + +VmaBlockMetadata::VmaBlockMetadata(VmaAllocator hAllocator) : + m_Size(0), + m_pAllocationCallbacks(hAllocator->GetAllocationCallbacks()) +{ +} + +#if VMA_STATS_STRING_ENABLED + +void VmaBlockMetadata::PrintDetailedMap_Begin(class VmaJsonWriter& json, + VkDeviceSize unusedBytes, + size_t allocationCount, + size_t unusedRangeCount) const +{ + json.BeginObject(); + + json.WriteString("TotalBytes"); + json.WriteNumber(GetSize()); + + json.WriteString("UnusedBytes"); + json.WriteNumber(unusedBytes); + + json.WriteString("Allocations"); + json.WriteNumber((uint64_t)allocationCount); + + json.WriteString("UnusedRanges"); + json.WriteNumber((uint64_t)unusedRangeCount); + + json.WriteString("Suballocations"); + json.BeginArray(); +} + +void VmaBlockMetadata::PrintDetailedMap_Allocation(class VmaJsonWriter& json, + VkDeviceSize offset, + VmaAllocation hAllocation) const +{ + json.BeginObject(true); + + json.WriteString("Offset"); + json.WriteNumber(offset); + + hAllocation->PrintParameters(json); + + json.EndObject(); +} + +void VmaBlockMetadata::PrintDetailedMap_UnusedRange(class VmaJsonWriter& json, + VkDeviceSize offset, + VkDeviceSize size) const +{ + json.BeginObject(true); + + json.WriteString("Offset"); + json.WriteNumber(offset); + + json.WriteString("Type"); + json.WriteString(VMA_SUBALLOCATION_TYPE_NAMES[VMA_SUBALLOCATION_TYPE_FREE]); + + json.WriteString("Size"); + json.WriteNumber(size); + + json.EndObject(); +} + +void VmaBlockMetadata::PrintDetailedMap_End(class VmaJsonWriter& json) const +{ + json.EndArray(); + json.EndObject(); +} + +#endif // #if VMA_STATS_STRING_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +// class VmaBlockMetadata_Generic + +VmaBlockMetadata_Generic::VmaBlockMetadata_Generic(VmaAllocator hAllocator) : + VmaBlockMetadata(hAllocator), + m_FreeCount(0), + m_SumFreeSize(0), + m_Suballocations(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + m_FreeSuballocationsBySize(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) +{ +} + +VmaBlockMetadata_Generic::~VmaBlockMetadata_Generic() +{ +} + +void VmaBlockMetadata_Generic::Init(VkDeviceSize size) +{ + VmaBlockMetadata::Init(size); + + m_FreeCount = 1; + m_SumFreeSize = size; + + VmaSuballocation suballoc = {}; + suballoc.offset = 0; + suballoc.size = size; + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + suballoc.hAllocation = VK_NULL_HANDLE; + + VMA_ASSERT(size > VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER); + m_Suballocations.push_back(suballoc); + VmaSuballocationList::iterator suballocItem = m_Suballocations.end(); + --suballocItem; + m_FreeSuballocationsBySize.push_back(suballocItem); +} + +bool VmaBlockMetadata_Generic::Validate() const +{ + VMA_VALIDATE(!m_Suballocations.empty()); + + // Expected offset of new suballocation as calculated from previous ones. + VkDeviceSize calculatedOffset = 0; + // Expected number of free suballocations as calculated from traversing their list. + uint32_t calculatedFreeCount = 0; + // Expected sum size of free suballocations as calculated from traversing their list. + VkDeviceSize calculatedSumFreeSize = 0; + // Expected number of free suballocations that should be registered in + // m_FreeSuballocationsBySize calculated from traversing their list. + size_t freeSuballocationsToRegister = 0; + // True if previous visited suballocation was free. + bool prevFree = false; + + for(VmaSuballocationList::const_iterator suballocItem = m_Suballocations.cbegin(); + suballocItem != m_Suballocations.cend(); + ++suballocItem) + { + const VmaSuballocation& subAlloc = *suballocItem; + + // Actual offset of this suballocation doesn't match expected one. + VMA_VALIDATE(subAlloc.offset == calculatedOffset); + + const bool currFree = (subAlloc.type == VMA_SUBALLOCATION_TYPE_FREE); + // Two adjacent free suballocations are invalid. They should be merged. + VMA_VALIDATE(!prevFree || !currFree); + + VMA_VALIDATE(currFree == (subAlloc.hAllocation == VK_NULL_HANDLE)); + + if(currFree) + { + calculatedSumFreeSize += subAlloc.size; + ++calculatedFreeCount; + if(subAlloc.size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + ++freeSuballocationsToRegister; + } + + // Margin required between allocations - every free space must be at least that large. + VMA_VALIDATE(subAlloc.size >= VMA_DEBUG_MARGIN); + } + else + { + VMA_VALIDATE(subAlloc.hAllocation->GetOffset() == subAlloc.offset); + VMA_VALIDATE(subAlloc.hAllocation->GetSize() == subAlloc.size); + + // Margin required between allocations - previous allocation must be free. + VMA_VALIDATE(VMA_DEBUG_MARGIN == 0 || prevFree); + } + + calculatedOffset += subAlloc.size; + prevFree = currFree; + } + + // Number of free suballocations registered in m_FreeSuballocationsBySize doesn't + // match expected one. + VMA_VALIDATE(m_FreeSuballocationsBySize.size() == freeSuballocationsToRegister); + + VkDeviceSize lastSize = 0; + for(size_t i = 0; i < m_FreeSuballocationsBySize.size(); ++i) + { + VmaSuballocationList::iterator suballocItem = m_FreeSuballocationsBySize[i]; + + // Only free suballocations can be registered in m_FreeSuballocationsBySize. + VMA_VALIDATE(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE); + // They must be sorted by size ascending. + VMA_VALIDATE(suballocItem->size >= lastSize); + + lastSize = suballocItem->size; + } + + // Check if totals match calculacted values. + VMA_VALIDATE(ValidateFreeSuballocationList()); + VMA_VALIDATE(calculatedOffset == GetSize()); + VMA_VALIDATE(calculatedSumFreeSize == m_SumFreeSize); + VMA_VALIDATE(calculatedFreeCount == m_FreeCount); + + return true; +} + +VkDeviceSize VmaBlockMetadata_Generic::GetUnusedRangeSizeMax() const +{ + if(!m_FreeSuballocationsBySize.empty()) + { + return m_FreeSuballocationsBySize.back()->size; + } + else + { + return 0; + } +} + +bool VmaBlockMetadata_Generic::IsEmpty() const +{ + return (m_Suballocations.size() == 1) && (m_FreeCount == 1); +} + +void VmaBlockMetadata_Generic::CalcAllocationStatInfo(VmaStatInfo& outInfo) const +{ + outInfo.blockCount = 1; + + const uint32_t rangeCount = (uint32_t)m_Suballocations.size(); + outInfo.allocationCount = rangeCount - m_FreeCount; + outInfo.unusedRangeCount = m_FreeCount; + + outInfo.unusedBytes = m_SumFreeSize; + outInfo.usedBytes = GetSize() - outInfo.unusedBytes; + + outInfo.allocationSizeMin = UINT64_MAX; + outInfo.allocationSizeMax = 0; + outInfo.unusedRangeSizeMin = UINT64_MAX; + outInfo.unusedRangeSizeMax = 0; + + for(VmaSuballocationList::const_iterator suballocItem = m_Suballocations.cbegin(); + suballocItem != m_Suballocations.cend(); + ++suballocItem) + { + const VmaSuballocation& suballoc = *suballocItem; + if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); + outInfo.allocationSizeMax = VMA_MAX(outInfo.allocationSizeMax, suballoc.size); + } + else + { + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, suballoc.size); + outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, suballoc.size); + } + } +} + +void VmaBlockMetadata_Generic::AddPoolStats(VmaPoolStats& inoutStats) const +{ + const uint32_t rangeCount = (uint32_t)m_Suballocations.size(); + + inoutStats.size += GetSize(); + inoutStats.unusedSize += m_SumFreeSize; + inoutStats.allocationCount += rangeCount - m_FreeCount; + inoutStats.unusedRangeCount += m_FreeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, GetUnusedRangeSizeMax()); +} + +#if VMA_STATS_STRING_ENABLED + +void VmaBlockMetadata_Generic::PrintDetailedMap(class VmaJsonWriter& json) const +{ + PrintDetailedMap_Begin(json, + m_SumFreeSize, // unusedBytes + m_Suballocations.size() - (size_t)m_FreeCount, // allocationCount + m_FreeCount); // unusedRangeCount + + size_t i = 0; + for(VmaSuballocationList::const_iterator suballocItem = m_Suballocations.cbegin(); + suballocItem != m_Suballocations.cend(); + ++suballocItem, ++i) + { + if(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE) + { + PrintDetailedMap_UnusedRange(json, suballocItem->offset, suballocItem->size); + } + else + { + PrintDetailedMap_Allocation(json, suballocItem->offset, suballocItem->hAllocation); + } + } + + PrintDetailedMap_End(json); +} + +#endif // #if VMA_STATS_STRING_ENABLED + +bool VmaBlockMetadata_Generic::CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) +{ + VMA_ASSERT(allocSize > 0); + VMA_ASSERT(!upperAddress); + VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(pAllocationRequest != VMA_NULL); + VMA_HEAVY_ASSERT(Validate()); + + pAllocationRequest->type = VmaAllocationRequestType::Normal; + + // There is not enough total free space in this block to fullfill the request: Early return. + if(canMakeOtherLost == false && + m_SumFreeSize < allocSize + 2 * VMA_DEBUG_MARGIN) + { + return false; + } + + // New algorithm, efficiently searching freeSuballocationsBySize. + const size_t freeSuballocCount = m_FreeSuballocationsBySize.size(); + if(freeSuballocCount > 0) + { + if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT) + { + // Find first free suballocation with size not less than allocSize + 2 * VMA_DEBUG_MARGIN. + VmaSuballocationList::iterator* const it = VmaBinaryFindFirstNotLess( + m_FreeSuballocationsBySize.data(), + m_FreeSuballocationsBySize.data() + freeSuballocCount, + allocSize + 2 * VMA_DEBUG_MARGIN, + VmaSuballocationItemSizeLess()); + size_t index = it - m_FreeSuballocationsBySize.data(); + for(; index < freeSuballocCount; ++index) + { + if(CheckAllocation( + currentFrameIndex, + frameInUseCount, + bufferImageGranularity, + allocSize, + allocAlignment, + allocType, + m_FreeSuballocationsBySize[index], + false, // canMakeOtherLost + &pAllocationRequest->offset, + &pAllocationRequest->itemsToMakeLostCount, + &pAllocationRequest->sumFreeSize, + &pAllocationRequest->sumItemSize)) + { + pAllocationRequest->item = m_FreeSuballocationsBySize[index]; + return true; + } + } + } + else if(strategy == VMA_ALLOCATION_INTERNAL_STRATEGY_MIN_OFFSET) + { + for(VmaSuballocationList::iterator it = m_Suballocations.begin(); + it != m_Suballocations.end(); + ++it) + { + if(it->type == VMA_SUBALLOCATION_TYPE_FREE && CheckAllocation( + currentFrameIndex, + frameInUseCount, + bufferImageGranularity, + allocSize, + allocAlignment, + allocType, + it, + false, // canMakeOtherLost + &pAllocationRequest->offset, + &pAllocationRequest->itemsToMakeLostCount, + &pAllocationRequest->sumFreeSize, + &pAllocationRequest->sumItemSize)) + { + pAllocationRequest->item = it; + return true; + } + } + } + else // WORST_FIT, FIRST_FIT + { + // Search staring from biggest suballocations. + for(size_t index = freeSuballocCount; index--; ) + { + if(CheckAllocation( + currentFrameIndex, + frameInUseCount, + bufferImageGranularity, + allocSize, + allocAlignment, + allocType, + m_FreeSuballocationsBySize[index], + false, // canMakeOtherLost + &pAllocationRequest->offset, + &pAllocationRequest->itemsToMakeLostCount, + &pAllocationRequest->sumFreeSize, + &pAllocationRequest->sumItemSize)) + { + pAllocationRequest->item = m_FreeSuballocationsBySize[index]; + return true; + } + } + } + } + + if(canMakeOtherLost) + { + // Brute-force algorithm. TODO: Come up with something better. + + bool found = false; + VmaAllocationRequest tmpAllocRequest = {}; + tmpAllocRequest.type = VmaAllocationRequestType::Normal; + for(VmaSuballocationList::iterator suballocIt = m_Suballocations.begin(); + suballocIt != m_Suballocations.end(); + ++suballocIt) + { + if(suballocIt->type == VMA_SUBALLOCATION_TYPE_FREE || + suballocIt->hAllocation->CanBecomeLost()) + { + if(CheckAllocation( + currentFrameIndex, + frameInUseCount, + bufferImageGranularity, + allocSize, + allocAlignment, + allocType, + suballocIt, + canMakeOtherLost, + &tmpAllocRequest.offset, + &tmpAllocRequest.itemsToMakeLostCount, + &tmpAllocRequest.sumFreeSize, + &tmpAllocRequest.sumItemSize)) + { + if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT) + { + *pAllocationRequest = tmpAllocRequest; + pAllocationRequest->item = suballocIt; + break; + } + if(!found || tmpAllocRequest.CalcCost() < pAllocationRequest->CalcCost()) + { + *pAllocationRequest = tmpAllocRequest; + pAllocationRequest->item = suballocIt; + found = true; + } + } + } + } + + return found; + } + + return false; +} + +bool VmaBlockMetadata_Generic::MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest) +{ + VMA_ASSERT(pAllocationRequest && pAllocationRequest->type == VmaAllocationRequestType::Normal); + + while(pAllocationRequest->itemsToMakeLostCount > 0) + { + if(pAllocationRequest->item->type == VMA_SUBALLOCATION_TYPE_FREE) + { + ++pAllocationRequest->item; + } + VMA_ASSERT(pAllocationRequest->item != m_Suballocations.end()); + VMA_ASSERT(pAllocationRequest->item->hAllocation != VK_NULL_HANDLE); + VMA_ASSERT(pAllocationRequest->item->hAllocation->CanBecomeLost()); + if(pAllocationRequest->item->hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + { + pAllocationRequest->item = FreeSuballocation(pAllocationRequest->item); + --pAllocationRequest->itemsToMakeLostCount; + } + else + { + return false; + } + } + + VMA_HEAVY_ASSERT(Validate()); + VMA_ASSERT(pAllocationRequest->item != m_Suballocations.end()); + VMA_ASSERT(pAllocationRequest->item->type == VMA_SUBALLOCATION_TYPE_FREE); + + return true; +} + +uint32_t VmaBlockMetadata_Generic::MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) +{ + uint32_t lostAllocationCount = 0; + for(VmaSuballocationList::iterator it = m_Suballocations.begin(); + it != m_Suballocations.end(); + ++it) + { + if(it->type != VMA_SUBALLOCATION_TYPE_FREE && + it->hAllocation->CanBecomeLost() && + it->hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + { + it = FreeSuballocation(it); + ++lostAllocationCount; + } + } + return lostAllocationCount; +} + +VkResult VmaBlockMetadata_Generic::CheckCorruption(const void* pBlockData) +{ + for(VmaSuballocationList::iterator it = m_Suballocations.begin(); + it != m_Suballocations.end(); + ++it) + { + if(it->type != VMA_SUBALLOCATION_TYPE_FREE) + { + if(!VmaValidateMagicValue(pBlockData, it->offset - VMA_DEBUG_MARGIN)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE VALIDATED ALLOCATION!"); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + if(!VmaValidateMagicValue(pBlockData, it->offset + it->size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + } + } + + return VK_SUCCESS; +} + +void VmaBlockMetadata_Generic::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation) +{ + VMA_ASSERT(request.type == VmaAllocationRequestType::Normal); + VMA_ASSERT(request.item != m_Suballocations.end()); + VmaSuballocation& suballoc = *request.item; + // Given suballocation is a free block. + VMA_ASSERT(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + // Given offset is inside this suballocation. + VMA_ASSERT(request.offset >= suballoc.offset); + const VkDeviceSize paddingBegin = request.offset - suballoc.offset; + VMA_ASSERT(suballoc.size >= paddingBegin + allocSize); + const VkDeviceSize paddingEnd = suballoc.size - paddingBegin - allocSize; + + // Unregister this free suballocation from m_FreeSuballocationsBySize and update + // it to become used. + UnregisterFreeSuballocation(request.item); + + suballoc.offset = request.offset; + suballoc.size = allocSize; + suballoc.type = type; + suballoc.hAllocation = hAllocation; + + // If there are any free bytes remaining at the end, insert new free suballocation after current one. + if(paddingEnd) + { + VmaSuballocation paddingSuballoc = {}; + paddingSuballoc.offset = request.offset + allocSize; + paddingSuballoc.size = paddingEnd; + paddingSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + VmaSuballocationList::iterator next = request.item; + ++next; + const VmaSuballocationList::iterator paddingEndItem = + m_Suballocations.insert(next, paddingSuballoc); + RegisterFreeSuballocation(paddingEndItem); + } + + // If there are any free bytes remaining at the beginning, insert new free suballocation before current one. + if(paddingBegin) + { + VmaSuballocation paddingSuballoc = {}; + paddingSuballoc.offset = request.offset - paddingBegin; + paddingSuballoc.size = paddingBegin; + paddingSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + const VmaSuballocationList::iterator paddingBeginItem = + m_Suballocations.insert(request.item, paddingSuballoc); + RegisterFreeSuballocation(paddingBeginItem); + } + + // Update totals. + m_FreeCount = m_FreeCount - 1; + if(paddingBegin > 0) + { + ++m_FreeCount; + } + if(paddingEnd > 0) + { + ++m_FreeCount; + } + m_SumFreeSize -= allocSize; +} + +void VmaBlockMetadata_Generic::Free(const VmaAllocation allocation) +{ + for(VmaSuballocationList::iterator suballocItem = m_Suballocations.begin(); + suballocItem != m_Suballocations.end(); + ++suballocItem) + { + VmaSuballocation& suballoc = *suballocItem; + if(suballoc.hAllocation == allocation) + { + FreeSuballocation(suballocItem); + VMA_HEAVY_ASSERT(Validate()); + return; + } + } + VMA_ASSERT(0 && "Not found!"); +} + +void VmaBlockMetadata_Generic::FreeAtOffset(VkDeviceSize offset) +{ + for(VmaSuballocationList::iterator suballocItem = m_Suballocations.begin(); + suballocItem != m_Suballocations.end(); + ++suballocItem) + { + VmaSuballocation& suballoc = *suballocItem; + if(suballoc.offset == offset) + { + FreeSuballocation(suballocItem); + return; + } + } + VMA_ASSERT(0 && "Not found!"); +} + +bool VmaBlockMetadata_Generic::ValidateFreeSuballocationList() const +{ + VkDeviceSize lastSize = 0; + for(size_t i = 0, count = m_FreeSuballocationsBySize.size(); i < count; ++i) + { + const VmaSuballocationList::iterator it = m_FreeSuballocationsBySize[i]; + + VMA_VALIDATE(it->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_VALIDATE(it->size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER); + VMA_VALIDATE(it->size >= lastSize); + lastSize = it->size; + } + return true; +} + +bool VmaBlockMetadata_Generic::CheckAllocation( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + VmaSuballocationList::const_iterator suballocItem, + bool canMakeOtherLost, + VkDeviceSize* pOffset, + size_t* itemsToMakeLostCount, + VkDeviceSize* pSumFreeSize, + VkDeviceSize* pSumItemSize) const +{ + VMA_ASSERT(allocSize > 0); + VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(suballocItem != m_Suballocations.cend()); + VMA_ASSERT(pOffset != VMA_NULL); + + *itemsToMakeLostCount = 0; + *pSumFreeSize = 0; + *pSumItemSize = 0; + + if(canMakeOtherLost) + { + if(suballocItem->type == VMA_SUBALLOCATION_TYPE_FREE) + { + *pSumFreeSize = suballocItem->size; + } + else + { + if(suballocItem->hAllocation->CanBecomeLost() && + suballocItem->hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) + { + ++*itemsToMakeLostCount; + *pSumItemSize = suballocItem->size; + } + else + { + return false; + } + } + + // Remaining size is too small for this request: Early return. + if(GetSize() - suballocItem->offset < allocSize) + { + return false; + } + + // Start from offset equal to beginning of this suballocation. + *pOffset = suballocItem->offset; + + // Apply VMA_DEBUG_MARGIN at the beginning. + if(VMA_DEBUG_MARGIN > 0) + { + *pOffset += VMA_DEBUG_MARGIN; + } + + // Apply alignment. + *pOffset = VmaAlignUp(*pOffset, allocAlignment); + + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if(bufferImageGranularity > 1) + { + bool bufferImageGranularityConflict = false; + VmaSuballocationList::const_iterator prevSuballocItem = suballocItem; + while(prevSuballocItem != m_Suballocations.cbegin()) + { + --prevSuballocItem; + const VmaSuballocation& prevSuballoc = *prevSuballocItem; + if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, *pOffset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if(bufferImageGranularityConflict) + { + *pOffset = VmaAlignUp(*pOffset, bufferImageGranularity); + } + } + + // Now that we have final *pOffset, check if we are past suballocItem. + // If yes, return false - this function should be called for another suballocItem as starting point. + if(*pOffset >= suballocItem->offset + suballocItem->size) + { + return false; + } + + // Calculate padding at the beginning based on current offset. + const VkDeviceSize paddingBegin = *pOffset - suballocItem->offset; + + // Calculate required margin at the end. + const VkDeviceSize requiredEndMargin = VMA_DEBUG_MARGIN; + + const VkDeviceSize totalSize = paddingBegin + allocSize + requiredEndMargin; + // Another early return check. + if(suballocItem->offset + totalSize > GetSize()) + { + return false; + } + + // Advance lastSuballocItem until desired size is reached. + // Update itemsToMakeLostCount. + VmaSuballocationList::const_iterator lastSuballocItem = suballocItem; + if(totalSize > suballocItem->size) + { + VkDeviceSize remainingSize = totalSize - suballocItem->size; + while(remainingSize > 0) + { + ++lastSuballocItem; + if(lastSuballocItem == m_Suballocations.cend()) + { + return false; + } + if(lastSuballocItem->type == VMA_SUBALLOCATION_TYPE_FREE) + { + *pSumFreeSize += lastSuballocItem->size; + } + else + { + VMA_ASSERT(lastSuballocItem->hAllocation != VK_NULL_HANDLE); + if(lastSuballocItem->hAllocation->CanBecomeLost() && + lastSuballocItem->hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) + { + ++*itemsToMakeLostCount; + *pSumItemSize += lastSuballocItem->size; + } + else + { + return false; + } + } + remainingSize = (lastSuballocItem->size < remainingSize) ? + remainingSize - lastSuballocItem->size : 0; + } + } + + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, we must mark more allocations lost or fail. + if(bufferImageGranularity > 1) + { + VmaSuballocationList::const_iterator nextSuballocItem = lastSuballocItem; + ++nextSuballocItem; + while(nextSuballocItem != m_Suballocations.cend()) + { + const VmaSuballocation& nextSuballoc = *nextSuballocItem; + if(VmaBlocksOnSamePage(*pOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + VMA_ASSERT(nextSuballoc.hAllocation != VK_NULL_HANDLE); + if(nextSuballoc.hAllocation->CanBecomeLost() && + nextSuballoc.hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) + { + ++*itemsToMakeLostCount; + } + else + { + return false; + } + } + } + else + { + // Already on next page. + break; + } + ++nextSuballocItem; + } + } + } + else + { + const VmaSuballocation& suballoc = *suballocItem; + VMA_ASSERT(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + + *pSumFreeSize = suballoc.size; + + // Size of this suballocation is too small for this request: Early return. + if(suballoc.size < allocSize) + { + return false; + } + + // Start from offset equal to beginning of this suballocation. + *pOffset = suballoc.offset; + + // Apply VMA_DEBUG_MARGIN at the beginning. + if(VMA_DEBUG_MARGIN > 0) + { + *pOffset += VMA_DEBUG_MARGIN; + } + + // Apply alignment. + *pOffset = VmaAlignUp(*pOffset, allocAlignment); + + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if(bufferImageGranularity > 1) + { + bool bufferImageGranularityConflict = false; + VmaSuballocationList::const_iterator prevSuballocItem = suballocItem; + while(prevSuballocItem != m_Suballocations.cbegin()) + { + --prevSuballocItem; + const VmaSuballocation& prevSuballoc = *prevSuballocItem; + if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, *pOffset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if(bufferImageGranularityConflict) + { + *pOffset = VmaAlignUp(*pOffset, bufferImageGranularity); + } + } + + // Calculate padding at the beginning based on current offset. + const VkDeviceSize paddingBegin = *pOffset - suballoc.offset; + + // Calculate required margin at the end. + const VkDeviceSize requiredEndMargin = VMA_DEBUG_MARGIN; + + // Fail if requested size plus margin before and after is bigger than size of this suballocation. + if(paddingBegin + allocSize + requiredEndMargin > suballoc.size) + { + return false; + } + + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if(bufferImageGranularity > 1) + { + VmaSuballocationList::const_iterator nextSuballocItem = suballocItem; + ++nextSuballocItem; + while(nextSuballocItem != m_Suballocations.cend()) + { + const VmaSuballocation& nextSuballoc = *nextSuballocItem; + if(VmaBlocksOnSamePage(*pOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + return false; + } + } + else + { + // Already on next page. + break; + } + ++nextSuballocItem; + } + } + } + + // All tests passed: Success. pOffset is already filled. + return true; +} + +void VmaBlockMetadata_Generic::MergeFreeWithNext(VmaSuballocationList::iterator item) +{ + VMA_ASSERT(item != m_Suballocations.end()); + VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); + + VmaSuballocationList::iterator nextItem = item; + ++nextItem; + VMA_ASSERT(nextItem != m_Suballocations.end()); + VMA_ASSERT(nextItem->type == VMA_SUBALLOCATION_TYPE_FREE); + + item->size += nextItem->size; + --m_FreeCount; + m_Suballocations.erase(nextItem); +} + +VmaSuballocationList::iterator VmaBlockMetadata_Generic::FreeSuballocation(VmaSuballocationList::iterator suballocItem) +{ + // Change this suballocation to be marked as free. + VmaSuballocation& suballoc = *suballocItem; + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + suballoc.hAllocation = VK_NULL_HANDLE; + + // Update totals. + ++m_FreeCount; + m_SumFreeSize += suballoc.size; + + // Merge with previous and/or next suballocation if it's also free. + bool mergeWithNext = false; + bool mergeWithPrev = false; + + VmaSuballocationList::iterator nextItem = suballocItem; + ++nextItem; + if((nextItem != m_Suballocations.end()) && (nextItem->type == VMA_SUBALLOCATION_TYPE_FREE)) + { + mergeWithNext = true; + } + + VmaSuballocationList::iterator prevItem = suballocItem; + if(suballocItem != m_Suballocations.begin()) + { + --prevItem; + if(prevItem->type == VMA_SUBALLOCATION_TYPE_FREE) + { + mergeWithPrev = true; + } + } + + if(mergeWithNext) + { + UnregisterFreeSuballocation(nextItem); + MergeFreeWithNext(suballocItem); + } + + if(mergeWithPrev) + { + UnregisterFreeSuballocation(prevItem); + MergeFreeWithNext(prevItem); + RegisterFreeSuballocation(prevItem); + return prevItem; + } + else + { + RegisterFreeSuballocation(suballocItem); + return suballocItem; + } +} + +void VmaBlockMetadata_Generic::RegisterFreeSuballocation(VmaSuballocationList::iterator item) +{ + VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(item->size > 0); + + // You may want to enable this validation at the beginning or at the end of + // this function, depending on what do you want to check. + VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); + + if(item->size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + if(m_FreeSuballocationsBySize.empty()) + { + m_FreeSuballocationsBySize.push_back(item); + } + else + { + VmaVectorInsertSorted(m_FreeSuballocationsBySize, item); + } + } + + //VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); +} + + +void VmaBlockMetadata_Generic::UnregisterFreeSuballocation(VmaSuballocationList::iterator item) +{ + VMA_ASSERT(item->type == VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(item->size > 0); + + // You may want to enable this validation at the beginning or at the end of + // this function, depending on what do you want to check. + VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); + + if(item->size >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + VmaSuballocationList::iterator* const it = VmaBinaryFindFirstNotLess( + m_FreeSuballocationsBySize.data(), + m_FreeSuballocationsBySize.data() + m_FreeSuballocationsBySize.size(), + item, + VmaSuballocationItemSizeLess()); + for(size_t index = it - m_FreeSuballocationsBySize.data(); + index < m_FreeSuballocationsBySize.size(); + ++index) + { + if(m_FreeSuballocationsBySize[index] == item) + { + VmaVectorRemove(m_FreeSuballocationsBySize, index); + return; + } + VMA_ASSERT((m_FreeSuballocationsBySize[index]->size == item->size) && "Not found."); + } + VMA_ASSERT(0 && "Not found."); + } + + //VMA_HEAVY_ASSERT(ValidateFreeSuballocationList()); +} + +bool VmaBlockMetadata_Generic::IsBufferImageGranularityConflictPossible( + VkDeviceSize bufferImageGranularity, + VmaSuballocationType& inOutPrevSuballocType) const +{ + if(bufferImageGranularity == 1 || IsEmpty()) + { + return false; + } + + VkDeviceSize minAlignment = VK_WHOLE_SIZE; + bool typeConflictFound = false; + for(VmaSuballocationList::const_iterator it = m_Suballocations.cbegin(); + it != m_Suballocations.cend(); + ++it) + { + const VmaSuballocationType suballocType = it->type; + if(suballocType != VMA_SUBALLOCATION_TYPE_FREE) + { + minAlignment = VMA_MIN(minAlignment, it->hAllocation->GetAlignment()); + if(VmaIsBufferImageGranularityConflict(inOutPrevSuballocType, suballocType)) + { + typeConflictFound = true; + } + inOutPrevSuballocType = suballocType; + } + } + + return typeConflictFound || minAlignment >= bufferImageGranularity; +} + +//////////////////////////////////////////////////////////////////////////////// +// class VmaBlockMetadata_Linear + +VmaBlockMetadata_Linear::VmaBlockMetadata_Linear(VmaAllocator hAllocator) : + VmaBlockMetadata(hAllocator), + m_SumFreeSize(0), + m_Suballocations0(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + m_Suballocations1(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + m_1stVectorIndex(0), + m_2ndVectorMode(SECOND_VECTOR_EMPTY), + m_1stNullItemsBeginCount(0), + m_1stNullItemsMiddleCount(0), + m_2ndNullItemsCount(0) +{ +} + +VmaBlockMetadata_Linear::~VmaBlockMetadata_Linear() +{ +} + +void VmaBlockMetadata_Linear::Init(VkDeviceSize size) +{ + VmaBlockMetadata::Init(size); + m_SumFreeSize = size; +} + +bool VmaBlockMetadata_Linear::Validate() const +{ + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + VMA_VALIDATE(suballocations2nd.empty() == (m_2ndVectorMode == SECOND_VECTOR_EMPTY)); + VMA_VALIDATE(!suballocations1st.empty() || + suballocations2nd.empty() || + m_2ndVectorMode != SECOND_VECTOR_RING_BUFFER); + + if(!suballocations1st.empty()) + { + // Null item at the beginning should be accounted into m_1stNullItemsBeginCount. + VMA_VALIDATE(suballocations1st[m_1stNullItemsBeginCount].hAllocation != VK_NULL_HANDLE); + // Null item at the end should be just pop_back(). + VMA_VALIDATE(suballocations1st.back().hAllocation != VK_NULL_HANDLE); + } + if(!suballocations2nd.empty()) + { + // Null item at the end should be just pop_back(). + VMA_VALIDATE(suballocations2nd.back().hAllocation != VK_NULL_HANDLE); + } + + VMA_VALIDATE(m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount <= suballocations1st.size()); + VMA_VALIDATE(m_2ndNullItemsCount <= suballocations2nd.size()); + + VkDeviceSize sumUsedSize = 0; + const size_t suballoc1stCount = suballocations1st.size(); + VkDeviceSize offset = VMA_DEBUG_MARGIN; + + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const size_t suballoc2ndCount = suballocations2nd.size(); + size_t nullItem2ndCount = 0; + for(size_t i = 0; i < suballoc2ndCount; ++i) + { + const VmaSuballocation& suballoc = suballocations2nd[i]; + const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + + VMA_VALIDATE(currFree == (suballoc.hAllocation == VK_NULL_HANDLE)); + VMA_VALIDATE(suballoc.offset >= offset); + + if(!currFree) + { + VMA_VALIDATE(suballoc.hAllocation->GetOffset() == suballoc.offset); + VMA_VALIDATE(suballoc.hAllocation->GetSize() == suballoc.size); + sumUsedSize += suballoc.size; + } + else + { + ++nullItem2ndCount; + } + + offset = suballoc.offset + suballoc.size + VMA_DEBUG_MARGIN; + } + + VMA_VALIDATE(nullItem2ndCount == m_2ndNullItemsCount); + } + + for(size_t i = 0; i < m_1stNullItemsBeginCount; ++i) + { + const VmaSuballocation& suballoc = suballocations1st[i]; + VMA_VALIDATE(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE && + suballoc.hAllocation == VK_NULL_HANDLE); + } + + size_t nullItem1stCount = m_1stNullItemsBeginCount; + + for(size_t i = m_1stNullItemsBeginCount; i < suballoc1stCount; ++i) + { + const VmaSuballocation& suballoc = suballocations1st[i]; + const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + + VMA_VALIDATE(currFree == (suballoc.hAllocation == VK_NULL_HANDLE)); + VMA_VALIDATE(suballoc.offset >= offset); + VMA_VALIDATE(i >= m_1stNullItemsBeginCount || currFree); + + if(!currFree) + { + VMA_VALIDATE(suballoc.hAllocation->GetOffset() == suballoc.offset); + VMA_VALIDATE(suballoc.hAllocation->GetSize() == suballoc.size); + sumUsedSize += suballoc.size; + } + else + { + ++nullItem1stCount; + } + + offset = suballoc.offset + suballoc.size + VMA_DEBUG_MARGIN; + } + VMA_VALIDATE(nullItem1stCount == m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount); + + if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + const size_t suballoc2ndCount = suballocations2nd.size(); + size_t nullItem2ndCount = 0; + for(size_t i = suballoc2ndCount; i--; ) + { + const VmaSuballocation& suballoc = suballocations2nd[i]; + const bool currFree = (suballoc.type == VMA_SUBALLOCATION_TYPE_FREE); + + VMA_VALIDATE(currFree == (suballoc.hAllocation == VK_NULL_HANDLE)); + VMA_VALIDATE(suballoc.offset >= offset); + + if(!currFree) + { + VMA_VALIDATE(suballoc.hAllocation->GetOffset() == suballoc.offset); + VMA_VALIDATE(suballoc.hAllocation->GetSize() == suballoc.size); + sumUsedSize += suballoc.size; + } + else + { + ++nullItem2ndCount; + } + + offset = suballoc.offset + suballoc.size + VMA_DEBUG_MARGIN; + } + + VMA_VALIDATE(nullItem2ndCount == m_2ndNullItemsCount); + } + + VMA_VALIDATE(offset <= GetSize()); + VMA_VALIDATE(m_SumFreeSize == GetSize() - sumUsedSize); + + return true; +} + +size_t VmaBlockMetadata_Linear::GetAllocationCount() const +{ + return AccessSuballocations1st().size() - (m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount) + + AccessSuballocations2nd().size() - m_2ndNullItemsCount; +} + +VkDeviceSize VmaBlockMetadata_Linear::GetUnusedRangeSizeMax() const +{ + const VkDeviceSize size = GetSize(); + + /* + We don't consider gaps inside allocation vectors with freed allocations because + they are not suitable for reuse in linear allocator. We consider only space that + is available for new allocations. + */ + if(IsEmpty()) + { + return size; + } + + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + + switch(m_2ndVectorMode) + { + case SECOND_VECTOR_EMPTY: + /* + Available space is after end of 1st, as well as before beginning of 1st (which + whould make it a ring buffer). + */ + { + const size_t suballocations1stCount = suballocations1st.size(); + VMA_ASSERT(suballocations1stCount > m_1stNullItemsBeginCount); + const VmaSuballocation& firstSuballoc = suballocations1st[m_1stNullItemsBeginCount]; + const VmaSuballocation& lastSuballoc = suballocations1st[suballocations1stCount - 1]; + return VMA_MAX( + firstSuballoc.offset, + size - (lastSuballoc.offset + lastSuballoc.size)); + } + break; + + case SECOND_VECTOR_RING_BUFFER: + /* + Available space is only between end of 2nd and beginning of 1st. + */ + { + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const VmaSuballocation& lastSuballoc2nd = suballocations2nd.back(); + const VmaSuballocation& firstSuballoc1st = suballocations1st[m_1stNullItemsBeginCount]; + return firstSuballoc1st.offset - (lastSuballoc2nd.offset + lastSuballoc2nd.size); + } + break; + + case SECOND_VECTOR_DOUBLE_STACK: + /* + Available space is only between end of 1st and top of 2nd. + */ + { + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const VmaSuballocation& topSuballoc2nd = suballocations2nd.back(); + const VmaSuballocation& lastSuballoc1st = suballocations1st.back(); + return topSuballoc2nd.offset - (lastSuballoc1st.offset + lastSuballoc1st.size); + } + break; + + default: + VMA_ASSERT(0); + return 0; + } +} + +void VmaBlockMetadata_Linear::CalcAllocationStatInfo(VmaStatInfo& outInfo) const +{ + const VkDeviceSize size = GetSize(); + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const size_t suballoc1stCount = suballocations1st.size(); + const size_t suballoc2ndCount = suballocations2nd.size(); + + outInfo.blockCount = 1; + outInfo.allocationCount = (uint32_t)GetAllocationCount(); + outInfo.unusedRangeCount = 0; + outInfo.usedBytes = 0; + outInfo.allocationSizeMin = UINT64_MAX; + outInfo.allocationSizeMax = 0; + outInfo.unusedRangeSizeMin = UINT64_MAX; + outInfo.unusedRangeSizeMax = 0; + + VkDeviceSize lastOffset = 0; + + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = 0; + while(lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while(nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); + outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + outInfo.usedBytes += suballoc.size; + outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); + outInfo.allocationSizeMax = VMA_MIN(outInfo.allocationSizeMax, suballoc.size); + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + if(lastOffset < freeSpace2ndTo1stEnd) + { + const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); + outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); + } + + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } + + size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; + const VkDeviceSize freeSpace1stTo2ndEnd = + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; + while(lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while(nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc1stIndex; + } + + // Found non-null allocation. + if(nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); + outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + outInfo.usedBytes += suballoc.size; + outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); + outInfo.allocationSizeMax = VMA_MIN(outInfo.allocationSizeMax, suballoc.size); + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + if(lastOffset < freeSpace1stTo2ndEnd) + { + const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); + outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); + } + + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } + + if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while(lastOffset < size) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while(nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + --nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); + outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + outInfo.usedBytes += suballoc.size; + outInfo.allocationSizeMin = VMA_MIN(outInfo.allocationSizeMin, suballoc.size); + outInfo.allocationSizeMax = VMA_MIN(outInfo.allocationSizeMax, suballoc.size); + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + // There is free space from lastOffset to size. + if(lastOffset < size) + { + const VkDeviceSize unusedRangeSize = size - lastOffset; + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusedRangeSize); + outInfo.unusedRangeSizeMax = VMA_MIN(outInfo.unusedRangeSizeMax, unusedRangeSize); + } + + // End of loop. + lastOffset = size; + } + } + } + + outInfo.unusedBytes = size - outInfo.usedBytes; +} + +void VmaBlockMetadata_Linear::AddPoolStats(VmaPoolStats& inoutStats) const +{ + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const VkDeviceSize size = GetSize(); + const size_t suballoc1stCount = suballocations1st.size(); + const size_t suballoc2ndCount = suballocations2nd.size(); + + inoutStats.size += size; + + VkDeviceSize lastOffset = 0; + + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = m_1stNullItemsBeginCount; + while(lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while(nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + inoutStats.unusedSize += unusedRangeSize; + ++inoutStats.unusedRangeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++inoutStats.allocationCount; + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + if(lastOffset < freeSpace2ndTo1stEnd) + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; + inoutStats.unusedSize += unusedRangeSize; + ++inoutStats.unusedRangeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + } + + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } + + size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; + const VkDeviceSize freeSpace1stTo2ndEnd = + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; + while(lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while(nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc1stIndex; + } + + // Found non-null allocation. + if(nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + inoutStats.unusedSize += unusedRangeSize; + ++inoutStats.unusedRangeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++inoutStats.allocationCount; + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + if(lastOffset < freeSpace1stTo2ndEnd) + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; + inoutStats.unusedSize += unusedRangeSize; + ++inoutStats.unusedRangeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + } + + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } + + if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while(lastOffset < size) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while(nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + --nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + inoutStats.unusedSize += unusedRangeSize; + ++inoutStats.unusedRangeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++inoutStats.allocationCount; + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + if(lastOffset < size) + { + // There is free space from lastOffset to size. + const VkDeviceSize unusedRangeSize = size - lastOffset; + inoutStats.unusedSize += unusedRangeSize; + ++inoutStats.unusedRangeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, unusedRangeSize); + } + + // End of loop. + lastOffset = size; + } + } + } +} + +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata_Linear::PrintDetailedMap(class VmaJsonWriter& json) const +{ + const VkDeviceSize size = GetSize(); + const SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + const SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + const size_t suballoc1stCount = suballocations1st.size(); + const size_t suballoc2ndCount = suballocations2nd.size(); + + // FIRST PASS + + size_t unusedRangeCount = 0; + VkDeviceSize usedBytes = 0; + + VkDeviceSize lastOffset = 0; + + size_t alloc2ndCount = 0; + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = 0; + while(lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while(nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + ++unusedRangeCount; + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++alloc2ndCount; + usedBytes += suballoc.size; + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + if(lastOffset < freeSpace2ndTo1stEnd) + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + ++unusedRangeCount; + } + + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } + + size_t nextAlloc1stIndex = m_1stNullItemsBeginCount; + size_t alloc1stCount = 0; + const VkDeviceSize freeSpace1stTo2ndEnd = + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? suballocations2nd.back().offset : size; + while(lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while(nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc1stIndex; + } + + // Found non-null allocation. + if(nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + ++unusedRangeCount; + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++alloc1stCount; + usedBytes += suballoc.size; + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + if(lastOffset < size) + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + ++unusedRangeCount; + } + + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } + + if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while(lastOffset < size) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while(nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + --nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + ++unusedRangeCount; + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + ++alloc2ndCount; + usedBytes += suballoc.size; + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + if(lastOffset < size) + { + // There is free space from lastOffset to size. + ++unusedRangeCount; + } + + // End of loop. + lastOffset = size; + } + } + } + + const VkDeviceSize unusedBytes = size - usedBytes; + PrintDetailedMap_Begin(json, unusedBytes, alloc1stCount + alloc2ndCount, unusedRangeCount); + + // SECOND PASS + lastOffset = 0; + + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + const VkDeviceSize freeSpace2ndTo1stEnd = suballocations1st[m_1stNullItemsBeginCount].offset; + size_t nextAlloc2ndIndex = 0; + while(lastOffset < freeSpace2ndTo1stEnd) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while(nextAlloc2ndIndex < suballoc2ndCount && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex < suballoc2ndCount) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.hAllocation); + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc2ndIndex; + } + // We are at the end. + else + { + if(lastOffset < freeSpace2ndTo1stEnd) + { + // There is free space from lastOffset to freeSpace2ndTo1stEnd. + const VkDeviceSize unusedRangeSize = freeSpace2ndTo1stEnd - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } + + // End of loop. + lastOffset = freeSpace2ndTo1stEnd; + } + } + } + + nextAlloc1stIndex = m_1stNullItemsBeginCount; + while(lastOffset < freeSpace1stTo2ndEnd) + { + // Find next non-null allocation or move nextAllocIndex to the end. + while(nextAlloc1stIndex < suballoc1stCount && + suballocations1st[nextAlloc1stIndex].hAllocation == VK_NULL_HANDLE) + { + ++nextAlloc1stIndex; + } + + // Found non-null allocation. + if(nextAlloc1stIndex < suballoc1stCount) + { + const VmaSuballocation& suballoc = suballocations1st[nextAlloc1stIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.hAllocation); + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + ++nextAlloc1stIndex; + } + // We are at the end. + else + { + if(lastOffset < freeSpace1stTo2ndEnd) + { + // There is free space from lastOffset to freeSpace1stTo2ndEnd. + const VkDeviceSize unusedRangeSize = freeSpace1stTo2ndEnd - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } + + // End of loop. + lastOffset = freeSpace1stTo2ndEnd; + } + } + + if(m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + size_t nextAlloc2ndIndex = suballocations2nd.size() - 1; + while(lastOffset < size) + { + // Find next non-null allocation or move nextAlloc2ndIndex to the end. + while(nextAlloc2ndIndex != SIZE_MAX && + suballocations2nd[nextAlloc2ndIndex].hAllocation == VK_NULL_HANDLE) + { + --nextAlloc2ndIndex; + } + + // Found non-null allocation. + if(nextAlloc2ndIndex != SIZE_MAX) + { + const VmaSuballocation& suballoc = suballocations2nd[nextAlloc2ndIndex]; + + // 1. Process free space before this allocation. + if(lastOffset < suballoc.offset) + { + // There is free space from lastOffset to suballoc.offset. + const VkDeviceSize unusedRangeSize = suballoc.offset - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } + + // 2. Process this allocation. + // There is allocation with suballoc.offset, suballoc.size. + PrintDetailedMap_Allocation(json, suballoc.offset, suballoc.hAllocation); + + // 3. Prepare for next iteration. + lastOffset = suballoc.offset + suballoc.size; + --nextAlloc2ndIndex; + } + // We are at the end. + else + { + if(lastOffset < size) + { + // There is free space from lastOffset to size. + const VkDeviceSize unusedRangeSize = size - lastOffset; + PrintDetailedMap_UnusedRange(json, lastOffset, unusedRangeSize); + } + + // End of loop. + lastOffset = size; + } + } + } + + PrintDetailedMap_End(json); +} +#endif // #if VMA_STATS_STRING_ENABLED + +bool VmaBlockMetadata_Linear::CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) +{ + VMA_ASSERT(allocSize > 0); + VMA_ASSERT(allocType != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(pAllocationRequest != VMA_NULL); + VMA_HEAVY_ASSERT(Validate()); + return upperAddress ? + CreateAllocationRequest_UpperAddress( + currentFrameIndex, frameInUseCount, bufferImageGranularity, + allocSize, allocAlignment, allocType, canMakeOtherLost, strategy, pAllocationRequest) : + CreateAllocationRequest_LowerAddress( + currentFrameIndex, frameInUseCount, bufferImageGranularity, + allocSize, allocAlignment, allocType, canMakeOtherLost, strategy, pAllocationRequest); +} + +bool VmaBlockMetadata_Linear::CreateAllocationRequest_UpperAddress( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) +{ + const VkDeviceSize size = GetSize(); + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + VMA_ASSERT(0 && "Trying to use pool with linear algorithm as double stack, while it is already being used as ring buffer."); + return false; + } + + // Try to allocate before 2nd.back(), or end of block if 2nd.empty(). + if(allocSize > size) + { + return false; + } + VkDeviceSize resultBaseOffset = size - allocSize; + if(!suballocations2nd.empty()) + { + const VmaSuballocation& lastSuballoc = suballocations2nd.back(); + resultBaseOffset = lastSuballoc.offset - allocSize; + if(allocSize > lastSuballoc.offset) + { + return false; + } + } + + // Start from offset equal to end of free space. + VkDeviceSize resultOffset = resultBaseOffset; + + // Apply VMA_DEBUG_MARGIN at the end. + if(VMA_DEBUG_MARGIN > 0) + { + if(resultOffset < VMA_DEBUG_MARGIN) + { + return false; + } + resultOffset -= VMA_DEBUG_MARGIN; + } + + // Apply alignment. + resultOffset = VmaAlignDown(resultOffset, allocAlignment); + + // Check next suballocations from 2nd for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if(bufferImageGranularity > 1 && !suballocations2nd.empty()) + { + bool bufferImageGranularityConflict = false; + for(size_t nextSuballocIndex = suballocations2nd.size(); nextSuballocIndex--; ) + { + const VmaSuballocation& nextSuballoc = suballocations2nd[nextSuballocIndex]; + if(VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(nextSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if(bufferImageGranularityConflict) + { + resultOffset = VmaAlignDown(resultOffset, bufferImageGranularity); + } + } + + // There is enough free space. + const VkDeviceSize endOf1st = !suballocations1st.empty() ? + suballocations1st.back().offset + suballocations1st.back().size : + 0; + if(endOf1st + VMA_DEBUG_MARGIN <= resultOffset) + { + // Check previous suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if(bufferImageGranularity > 1) + { + for(size_t prevSuballocIndex = suballocations1st.size(); prevSuballocIndex--; ) + { + const VmaSuballocation& prevSuballoc = suballocations1st[prevSuballocIndex]; + if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(allocType, prevSuballoc.type)) + { + return false; + } + } + else + { + // Already on next page. + break; + } + } + } + + // All tests passed: Success. + pAllocationRequest->offset = resultOffset; + pAllocationRequest->sumFreeSize = resultBaseOffset + allocSize - endOf1st; + pAllocationRequest->sumItemSize = 0; + // pAllocationRequest->item unused. + pAllocationRequest->itemsToMakeLostCount = 0; + pAllocationRequest->type = VmaAllocationRequestType::UpperAddress; + return true; + } + + return false; +} + +bool VmaBlockMetadata_Linear::CreateAllocationRequest_LowerAddress( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) +{ + const VkDeviceSize size = GetSize(); + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + if(m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + // Try to allocate at the end of 1st vector. + + VkDeviceSize resultBaseOffset = 0; + if(!suballocations1st.empty()) + { + const VmaSuballocation& lastSuballoc = suballocations1st.back(); + resultBaseOffset = lastSuballoc.offset + lastSuballoc.size; + } + + // Start from offset equal to beginning of free space. + VkDeviceSize resultOffset = resultBaseOffset; + + // Apply VMA_DEBUG_MARGIN at the beginning. + if(VMA_DEBUG_MARGIN > 0) + { + resultOffset += VMA_DEBUG_MARGIN; + } + + // Apply alignment. + resultOffset = VmaAlignUp(resultOffset, allocAlignment); + + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if(bufferImageGranularity > 1 && !suballocations1st.empty()) + { + bool bufferImageGranularityConflict = false; + for(size_t prevSuballocIndex = suballocations1st.size(); prevSuballocIndex--; ) + { + const VmaSuballocation& prevSuballoc = suballocations1st[prevSuballocIndex]; + if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if(bufferImageGranularityConflict) + { + resultOffset = VmaAlignUp(resultOffset, bufferImageGranularity); + } + } + + const VkDeviceSize freeSpaceEnd = m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK ? + suballocations2nd.back().offset : size; + + // There is enough free space at the end after alignment. + if(resultOffset + allocSize + VMA_DEBUG_MARGIN <= freeSpaceEnd) + { + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if(bufferImageGranularity > 1 && m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + for(size_t nextSuballocIndex = suballocations2nd.size(); nextSuballocIndex--; ) + { + const VmaSuballocation& nextSuballoc = suballocations2nd[nextSuballocIndex]; + if(VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + return false; + } + } + else + { + // Already on previous page. + break; + } + } + } + + // All tests passed: Success. + pAllocationRequest->offset = resultOffset; + pAllocationRequest->sumFreeSize = freeSpaceEnd - resultBaseOffset; + pAllocationRequest->sumItemSize = 0; + // pAllocationRequest->item, customData unused. + pAllocationRequest->type = VmaAllocationRequestType::EndOf1st; + pAllocationRequest->itemsToMakeLostCount = 0; + return true; + } + } + + // Wrap-around to end of 2nd vector. Try to allocate there, watching for the + // beginning of 1st vector as the end of free space. + if(m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + VMA_ASSERT(!suballocations1st.empty()); + + VkDeviceSize resultBaseOffset = 0; + if(!suballocations2nd.empty()) + { + const VmaSuballocation& lastSuballoc = suballocations2nd.back(); + resultBaseOffset = lastSuballoc.offset + lastSuballoc.size; + } + + // Start from offset equal to beginning of free space. + VkDeviceSize resultOffset = resultBaseOffset; + + // Apply VMA_DEBUG_MARGIN at the beginning. + if(VMA_DEBUG_MARGIN > 0) + { + resultOffset += VMA_DEBUG_MARGIN; + } + + // Apply alignment. + resultOffset = VmaAlignUp(resultOffset, allocAlignment); + + // Check previous suballocations for BufferImageGranularity conflicts. + // Make bigger alignment if necessary. + if(bufferImageGranularity > 1 && !suballocations2nd.empty()) + { + bool bufferImageGranularityConflict = false; + for(size_t prevSuballocIndex = suballocations2nd.size(); prevSuballocIndex--; ) + { + const VmaSuballocation& prevSuballoc = suballocations2nd[prevSuballocIndex]; + if(VmaBlocksOnSamePage(prevSuballoc.offset, prevSuballoc.size, resultOffset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(prevSuballoc.type, allocType)) + { + bufferImageGranularityConflict = true; + break; + } + } + else + // Already on previous page. + break; + } + if(bufferImageGranularityConflict) + { + resultOffset = VmaAlignUp(resultOffset, bufferImageGranularity); + } + } + + pAllocationRequest->itemsToMakeLostCount = 0; + pAllocationRequest->sumItemSize = 0; + size_t index1st = m_1stNullItemsBeginCount; + + if(canMakeOtherLost) + { + while(index1st < suballocations1st.size() && + resultOffset + allocSize + VMA_DEBUG_MARGIN > suballocations1st[index1st].offset) + { + // Next colliding allocation at the beginning of 1st vector found. Try to make it lost. + const VmaSuballocation& suballoc = suballocations1st[index1st]; + if(suballoc.type == VMA_SUBALLOCATION_TYPE_FREE) + { + // No problem. + } + else + { + VMA_ASSERT(suballoc.hAllocation != VK_NULL_HANDLE); + if(suballoc.hAllocation->CanBecomeLost() && + suballoc.hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) + { + ++pAllocationRequest->itemsToMakeLostCount; + pAllocationRequest->sumItemSize += suballoc.size; + } + else + { + return false; + } + } + ++index1st; + } + + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, we must mark more allocations lost or fail. + if(bufferImageGranularity > 1) + { + while(index1st < suballocations1st.size()) + { + const VmaSuballocation& suballoc = suballocations1st[index1st]; + if(VmaBlocksOnSamePage(resultOffset, allocSize, suballoc.offset, bufferImageGranularity)) + { + if(suballoc.hAllocation != VK_NULL_HANDLE) + { + // Not checking actual VmaIsBufferImageGranularityConflict(allocType, suballoc.type). + if(suballoc.hAllocation->CanBecomeLost() && + suballoc.hAllocation->GetLastUseFrameIndex() + frameInUseCount < currentFrameIndex) + { + ++pAllocationRequest->itemsToMakeLostCount; + pAllocationRequest->sumItemSize += suballoc.size; + } + else + { + return false; + } + } + } + else + { + // Already on next page. + break; + } + ++index1st; + } + } + + // Special case: There is not enough room at the end for this allocation, even after making all from the 1st lost. + if(index1st == suballocations1st.size() && + resultOffset + allocSize + VMA_DEBUG_MARGIN > size) + { + // TODO: This is a known bug that it's not yet implemented and the allocation is failing. + VMA_DEBUG_LOG("Unsupported special case in custom pool with linear allocation algorithm used as ring buffer with allocations that can be lost."); + } + } + + // There is enough free space at the end after alignment. + if((index1st == suballocations1st.size() && resultOffset + allocSize + VMA_DEBUG_MARGIN <= size) || + (index1st < suballocations1st.size() && resultOffset + allocSize + VMA_DEBUG_MARGIN <= suballocations1st[index1st].offset)) + { + // Check next suballocations for BufferImageGranularity conflicts. + // If conflict exists, allocation cannot be made here. + if(bufferImageGranularity > 1) + { + for(size_t nextSuballocIndex = index1st; + nextSuballocIndex < suballocations1st.size(); + nextSuballocIndex++) + { + const VmaSuballocation& nextSuballoc = suballocations1st[nextSuballocIndex]; + if(VmaBlocksOnSamePage(resultOffset, allocSize, nextSuballoc.offset, bufferImageGranularity)) + { + if(VmaIsBufferImageGranularityConflict(allocType, nextSuballoc.type)) + { + return false; + } + } + else + { + // Already on next page. + break; + } + } + } + + // All tests passed: Success. + pAllocationRequest->offset = resultOffset; + pAllocationRequest->sumFreeSize = + (index1st < suballocations1st.size() ? suballocations1st[index1st].offset : size) + - resultBaseOffset + - pAllocationRequest->sumItemSize; + pAllocationRequest->type = VmaAllocationRequestType::EndOf2nd; + // pAllocationRequest->item, customData unused. + return true; + } + } + + return false; +} + +bool VmaBlockMetadata_Linear::MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest) +{ + if(pAllocationRequest->itemsToMakeLostCount == 0) + { + return true; + } + + VMA_ASSERT(m_2ndVectorMode == SECOND_VECTOR_EMPTY || m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER); + + // We always start from 1st. + SuballocationVectorType* suballocations = &AccessSuballocations1st(); + size_t index = m_1stNullItemsBeginCount; + size_t madeLostCount = 0; + while(madeLostCount < pAllocationRequest->itemsToMakeLostCount) + { + if(index == suballocations->size()) + { + index = 0; + // If we get to the end of 1st, we wrap around to beginning of 2nd of 1st. + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + suballocations = &AccessSuballocations2nd(); + } + // else: m_2ndVectorMode == SECOND_VECTOR_EMPTY: + // suballocations continues pointing at AccessSuballocations1st(). + VMA_ASSERT(!suballocations->empty()); + } + VmaSuballocation& suballoc = (*suballocations)[index]; + if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + VMA_ASSERT(suballoc.hAllocation != VK_NULL_HANDLE); + VMA_ASSERT(suballoc.hAllocation->CanBecomeLost()); + if(suballoc.hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + { + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + suballoc.hAllocation = VK_NULL_HANDLE; + m_SumFreeSize += suballoc.size; + if(suballocations == &AccessSuballocations1st()) + { + ++m_1stNullItemsMiddleCount; + } + else + { + ++m_2ndNullItemsCount; + } + ++madeLostCount; + } + else + { + return false; + } + } + ++index; + } + + CleanupAfterFree(); + //VMA_HEAVY_ASSERT(Validate()); // Already called by ClanupAfterFree(). + + return true; +} + +uint32_t VmaBlockMetadata_Linear::MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) +{ + uint32_t lostAllocationCount = 0; + + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + for(size_t i = m_1stNullItemsBeginCount, count = suballocations1st.size(); i < count; ++i) + { + VmaSuballocation& suballoc = suballocations1st[i]; + if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE && + suballoc.hAllocation->CanBecomeLost() && + suballoc.hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + { + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + suballoc.hAllocation = VK_NULL_HANDLE; + ++m_1stNullItemsMiddleCount; + m_SumFreeSize += suballoc.size; + ++lostAllocationCount; + } + } + + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + for(size_t i = 0, count = suballocations2nd.size(); i < count; ++i) + { + VmaSuballocation& suballoc = suballocations2nd[i]; + if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE && + suballoc.hAllocation->CanBecomeLost() && + suballoc.hAllocation->MakeLost(currentFrameIndex, frameInUseCount)) + { + suballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + suballoc.hAllocation = VK_NULL_HANDLE; + ++m_2ndNullItemsCount; + m_SumFreeSize += suballoc.size; + ++lostAllocationCount; + } + } + + if(lostAllocationCount) + { + CleanupAfterFree(); + } + + return lostAllocationCount; +} + +VkResult VmaBlockMetadata_Linear::CheckCorruption(const void* pBlockData) +{ + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + for(size_t i = m_1stNullItemsBeginCount, count = suballocations1st.size(); i < count; ++i) + { + const VmaSuballocation& suballoc = suballocations1st[i]; + if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + if(!VmaValidateMagicValue(pBlockData, suballoc.offset - VMA_DEBUG_MARGIN)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE VALIDATED ALLOCATION!"); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + if(!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + } + } + + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + for(size_t i = 0, count = suballocations2nd.size(); i < count; ++i) + { + const VmaSuballocation& suballoc = suballocations2nd[i]; + if(suballoc.type != VMA_SUBALLOCATION_TYPE_FREE) + { + if(!VmaValidateMagicValue(pBlockData, suballoc.offset - VMA_DEBUG_MARGIN)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE VALIDATED ALLOCATION!"); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + if(!VmaValidateMagicValue(pBlockData, suballoc.offset + suballoc.size)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER VALIDATED ALLOCATION!"); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + } + } + + return VK_SUCCESS; +} + +void VmaBlockMetadata_Linear::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation) +{ + const VmaSuballocation newSuballoc = { request.offset, allocSize, hAllocation, type }; + + switch(request.type) + { + case VmaAllocationRequestType::UpperAddress: + { + VMA_ASSERT(m_2ndVectorMode != SECOND_VECTOR_RING_BUFFER && + "CRITICAL ERROR: Trying to use linear allocator as double stack while it was already used as ring buffer."); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + suballocations2nd.push_back(newSuballoc); + m_2ndVectorMode = SECOND_VECTOR_DOUBLE_STACK; + } + break; + case VmaAllocationRequestType::EndOf1st: + { + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + + VMA_ASSERT(suballocations1st.empty() || + request.offset >= suballocations1st.back().offset + suballocations1st.back().size); + // Check if it fits before the end of the block. + VMA_ASSERT(request.offset + allocSize <= GetSize()); + + suballocations1st.push_back(newSuballoc); + } + break; + case VmaAllocationRequestType::EndOf2nd: + { + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + // New allocation at the end of 2-part ring buffer, so before first allocation from 1st vector. + VMA_ASSERT(!suballocations1st.empty() && + request.offset + allocSize <= suballocations1st[m_1stNullItemsBeginCount].offset); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + switch(m_2ndVectorMode) + { + case SECOND_VECTOR_EMPTY: + // First allocation from second part ring buffer. + VMA_ASSERT(suballocations2nd.empty()); + m_2ndVectorMode = SECOND_VECTOR_RING_BUFFER; + break; + case SECOND_VECTOR_RING_BUFFER: + // 2-part ring buffer is already started. + VMA_ASSERT(!suballocations2nd.empty()); + break; + case SECOND_VECTOR_DOUBLE_STACK: + VMA_ASSERT(0 && "CRITICAL ERROR: Trying to use linear allocator as ring buffer while it was already used as double stack."); + break; + default: + VMA_ASSERT(0); + } + + suballocations2nd.push_back(newSuballoc); + } + break; + default: + VMA_ASSERT(0 && "CRITICAL INTERNAL ERROR."); + } + + m_SumFreeSize -= newSuballoc.size; +} + +void VmaBlockMetadata_Linear::Free(const VmaAllocation allocation) +{ + FreeAtOffset(allocation->GetOffset()); +} + +void VmaBlockMetadata_Linear::FreeAtOffset(VkDeviceSize offset) +{ + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + if(!suballocations1st.empty()) + { + // First allocation: Mark it as next empty at the beginning. + VmaSuballocation& firstSuballoc = suballocations1st[m_1stNullItemsBeginCount]; + if(firstSuballoc.offset == offset) + { + firstSuballoc.type = VMA_SUBALLOCATION_TYPE_FREE; + firstSuballoc.hAllocation = VK_NULL_HANDLE; + m_SumFreeSize += firstSuballoc.size; + ++m_1stNullItemsBeginCount; + CleanupAfterFree(); + return; + } + } + + // Last allocation in 2-part ring buffer or top of upper stack (same logic). + if(m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER || + m_2ndVectorMode == SECOND_VECTOR_DOUBLE_STACK) + { + VmaSuballocation& lastSuballoc = suballocations2nd.back(); + if(lastSuballoc.offset == offset) + { + m_SumFreeSize += lastSuballoc.size; + suballocations2nd.pop_back(); + CleanupAfterFree(); + return; + } + } + // Last allocation in 1st vector. + else if(m_2ndVectorMode == SECOND_VECTOR_EMPTY) + { + VmaSuballocation& lastSuballoc = suballocations1st.back(); + if(lastSuballoc.offset == offset) + { + m_SumFreeSize += lastSuballoc.size; + suballocations1st.pop_back(); + CleanupAfterFree(); + return; + } + } + + // Item from the middle of 1st vector. + { + VmaSuballocation refSuballoc; + refSuballoc.offset = offset; + // Rest of members stays uninitialized intentionally for better performance. + SuballocationVectorType::iterator it = VmaBinaryFindSorted( + suballocations1st.begin() + m_1stNullItemsBeginCount, + suballocations1st.end(), + refSuballoc, + VmaSuballocationOffsetLess()); + if(it != suballocations1st.end()) + { + it->type = VMA_SUBALLOCATION_TYPE_FREE; + it->hAllocation = VK_NULL_HANDLE; + ++m_1stNullItemsMiddleCount; + m_SumFreeSize += it->size; + CleanupAfterFree(); + return; + } + } + + if(m_2ndVectorMode != SECOND_VECTOR_EMPTY) + { + // Item from the middle of 2nd vector. + VmaSuballocation refSuballoc; + refSuballoc.offset = offset; + // Rest of members stays uninitialized intentionally for better performance. + SuballocationVectorType::iterator it = m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER ? + VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetLess()) : + VmaBinaryFindSorted(suballocations2nd.begin(), suballocations2nd.end(), refSuballoc, VmaSuballocationOffsetGreater()); + if(it != suballocations2nd.end()) + { + it->type = VMA_SUBALLOCATION_TYPE_FREE; + it->hAllocation = VK_NULL_HANDLE; + ++m_2ndNullItemsCount; + m_SumFreeSize += it->size; + CleanupAfterFree(); + return; + } + } + + VMA_ASSERT(0 && "Allocation to free not found in linear allocator!"); +} + +bool VmaBlockMetadata_Linear::ShouldCompact1st() const +{ + const size_t nullItemCount = m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount; + const size_t suballocCount = AccessSuballocations1st().size(); + return suballocCount > 32 && nullItemCount * 2 >= (suballocCount - nullItemCount) * 3; +} + +void VmaBlockMetadata_Linear::CleanupAfterFree() +{ + SuballocationVectorType& suballocations1st = AccessSuballocations1st(); + SuballocationVectorType& suballocations2nd = AccessSuballocations2nd(); + + if(IsEmpty()) + { + suballocations1st.clear(); + suballocations2nd.clear(); + m_1stNullItemsBeginCount = 0; + m_1stNullItemsMiddleCount = 0; + m_2ndNullItemsCount = 0; + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + } + else + { + const size_t suballoc1stCount = suballocations1st.size(); + const size_t nullItem1stCount = m_1stNullItemsBeginCount + m_1stNullItemsMiddleCount; + VMA_ASSERT(nullItem1stCount <= suballoc1stCount); + + // Find more null items at the beginning of 1st vector. + while(m_1stNullItemsBeginCount < suballoc1stCount && + suballocations1st[m_1stNullItemsBeginCount].hAllocation == VK_NULL_HANDLE) + { + ++m_1stNullItemsBeginCount; + --m_1stNullItemsMiddleCount; + } + + // Find more null items at the end of 1st vector. + while(m_1stNullItemsMiddleCount > 0 && + suballocations1st.back().hAllocation == VK_NULL_HANDLE) + { + --m_1stNullItemsMiddleCount; + suballocations1st.pop_back(); + } + + // Find more null items at the end of 2nd vector. + while(m_2ndNullItemsCount > 0 && + suballocations2nd.back().hAllocation == VK_NULL_HANDLE) + { + --m_2ndNullItemsCount; + suballocations2nd.pop_back(); + } + + // Find more null items at the beginning of 2nd vector. + while(m_2ndNullItemsCount > 0 && + suballocations2nd[0].hAllocation == VK_NULL_HANDLE) + { + --m_2ndNullItemsCount; + VmaVectorRemove(suballocations2nd, 0); + } + + if(ShouldCompact1st()) + { + const size_t nonNullItemCount = suballoc1stCount - nullItem1stCount; + size_t srcIndex = m_1stNullItemsBeginCount; + for(size_t dstIndex = 0; dstIndex < nonNullItemCount; ++dstIndex) + { + while(suballocations1st[srcIndex].hAllocation == VK_NULL_HANDLE) + { + ++srcIndex; + } + if(dstIndex != srcIndex) + { + suballocations1st[dstIndex] = suballocations1st[srcIndex]; + } + ++srcIndex; + } + suballocations1st.resize(nonNullItemCount); + m_1stNullItemsBeginCount = 0; + m_1stNullItemsMiddleCount = 0; + } + + // 2nd vector became empty. + if(suballocations2nd.empty()) + { + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + } + + // 1st vector became empty. + if(suballocations1st.size() - m_1stNullItemsBeginCount == 0) + { + suballocations1st.clear(); + m_1stNullItemsBeginCount = 0; + + if(!suballocations2nd.empty() && m_2ndVectorMode == SECOND_VECTOR_RING_BUFFER) + { + // Swap 1st with 2nd. Now 2nd is empty. + m_2ndVectorMode = SECOND_VECTOR_EMPTY; + m_1stNullItemsMiddleCount = m_2ndNullItemsCount; + while(m_1stNullItemsBeginCount < suballocations2nd.size() && + suballocations2nd[m_1stNullItemsBeginCount].hAllocation == VK_NULL_HANDLE) + { + ++m_1stNullItemsBeginCount; + --m_1stNullItemsMiddleCount; + } + m_2ndNullItemsCount = 0; + m_1stVectorIndex ^= 1; + } + } + } + + VMA_HEAVY_ASSERT(Validate()); +} + + +//////////////////////////////////////////////////////////////////////////////// +// class VmaBlockMetadata_Buddy + +VmaBlockMetadata_Buddy::VmaBlockMetadata_Buddy(VmaAllocator hAllocator) : + VmaBlockMetadata(hAllocator), + m_Root(VMA_NULL), + m_AllocationCount(0), + m_FreeCount(1), + m_SumFreeSize(0) +{ + memset(m_FreeList, 0, sizeof(m_FreeList)); +} + +VmaBlockMetadata_Buddy::~VmaBlockMetadata_Buddy() +{ + DeleteNode(m_Root); +} + +void VmaBlockMetadata_Buddy::Init(VkDeviceSize size) +{ + VmaBlockMetadata::Init(size); + + m_UsableSize = VmaPrevPow2(size); + m_SumFreeSize = m_UsableSize; + + // Calculate m_LevelCount. + m_LevelCount = 1; + while(m_LevelCount < MAX_LEVELS && + LevelToNodeSize(m_LevelCount) >= MIN_NODE_SIZE) + { + ++m_LevelCount; + } + + Node* rootNode = vma_new(GetAllocationCallbacks(), Node)(); + rootNode->offset = 0; + rootNode->type = Node::TYPE_FREE; + rootNode->parent = VMA_NULL; + rootNode->buddy = VMA_NULL; + + m_Root = rootNode; + AddToFreeListFront(0, rootNode); +} + +bool VmaBlockMetadata_Buddy::Validate() const +{ + // Validate tree. + ValidationContext ctx; + if(!ValidateNode(ctx, VMA_NULL, m_Root, 0, LevelToNodeSize(0))) + { + VMA_VALIDATE(false && "ValidateNode failed."); + } + VMA_VALIDATE(m_AllocationCount == ctx.calculatedAllocationCount); + VMA_VALIDATE(m_SumFreeSize == ctx.calculatedSumFreeSize); + + // Validate free node lists. + for(uint32_t level = 0; level < m_LevelCount; ++level) + { + VMA_VALIDATE(m_FreeList[level].front == VMA_NULL || + m_FreeList[level].front->free.prev == VMA_NULL); + + for(Node* node = m_FreeList[level].front; + node != VMA_NULL; + node = node->free.next) + { + VMA_VALIDATE(node->type == Node::TYPE_FREE); + + if(node->free.next == VMA_NULL) + { + VMA_VALIDATE(m_FreeList[level].back == node); + } + else + { + VMA_VALIDATE(node->free.next->free.prev == node); + } + } + } + + // Validate that free lists ar higher levels are empty. + for(uint32_t level = m_LevelCount; level < MAX_LEVELS; ++level) + { + VMA_VALIDATE(m_FreeList[level].front == VMA_NULL && m_FreeList[level].back == VMA_NULL); + } + + return true; +} + +VkDeviceSize VmaBlockMetadata_Buddy::GetUnusedRangeSizeMax() const +{ + for(uint32_t level = 0; level < m_LevelCount; ++level) + { + if(m_FreeList[level].front != VMA_NULL) + { + return LevelToNodeSize(level); + } + } + return 0; +} + +void VmaBlockMetadata_Buddy::CalcAllocationStatInfo(VmaStatInfo& outInfo) const +{ + const VkDeviceSize unusableSize = GetUnusableSize(); + + outInfo.blockCount = 1; + + outInfo.allocationCount = outInfo.unusedRangeCount = 0; + outInfo.usedBytes = outInfo.unusedBytes = 0; + + outInfo.allocationSizeMax = outInfo.unusedRangeSizeMax = 0; + outInfo.allocationSizeMin = outInfo.unusedRangeSizeMin = UINT64_MAX; + outInfo.allocationSizeAvg = outInfo.unusedRangeSizeAvg = 0; // Unused. + + CalcAllocationStatInfoNode(outInfo, m_Root, LevelToNodeSize(0)); + + if(unusableSize > 0) + { + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusableSize; + outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, unusableSize); + outInfo.unusedRangeSizeMin = VMA_MIN(outInfo.unusedRangeSizeMin, unusableSize); + } +} + +void VmaBlockMetadata_Buddy::AddPoolStats(VmaPoolStats& inoutStats) const +{ + const VkDeviceSize unusableSize = GetUnusableSize(); + + inoutStats.size += GetSize(); + inoutStats.unusedSize += m_SumFreeSize + unusableSize; + inoutStats.allocationCount += m_AllocationCount; + inoutStats.unusedRangeCount += m_FreeCount; + inoutStats.unusedRangeSizeMax = VMA_MAX(inoutStats.unusedRangeSizeMax, GetUnusedRangeSizeMax()); + + if(unusableSize > 0) + { + ++inoutStats.unusedRangeCount; + // Not updating inoutStats.unusedRangeSizeMax with unusableSize because this space is not available for allocations. + } +} + +#if VMA_STATS_STRING_ENABLED + +void VmaBlockMetadata_Buddy::PrintDetailedMap(class VmaJsonWriter& json) const +{ + // TODO optimize + VmaStatInfo stat; + CalcAllocationStatInfo(stat); + + PrintDetailedMap_Begin( + json, + stat.unusedBytes, + stat.allocationCount, + stat.unusedRangeCount); + + PrintDetailedMapNode(json, m_Root, LevelToNodeSize(0)); + + const VkDeviceSize unusableSize = GetUnusableSize(); + if(unusableSize > 0) + { + PrintDetailedMap_UnusedRange(json, + m_UsableSize, // offset + unusableSize); // size + } + + PrintDetailedMap_End(json); +} + +#endif // #if VMA_STATS_STRING_ENABLED + +bool VmaBlockMetadata_Buddy::CreateAllocationRequest( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VkDeviceSize bufferImageGranularity, + VkDeviceSize allocSize, + VkDeviceSize allocAlignment, + bool upperAddress, + VmaSuballocationType allocType, + bool canMakeOtherLost, + uint32_t strategy, + VmaAllocationRequest* pAllocationRequest) +{ + VMA_ASSERT(!upperAddress && "VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT can be used only with linear algorithm."); + + // Simple way to respect bufferImageGranularity. May be optimized some day. + // Whenever it might be an OPTIMAL image... + if(allocType == VMA_SUBALLOCATION_TYPE_UNKNOWN || + allocType == VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN || + allocType == VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL) + { + allocAlignment = VMA_MAX(allocAlignment, bufferImageGranularity); + allocSize = VMA_MAX(allocSize, bufferImageGranularity); + } + + if(allocSize > m_UsableSize) + { + return false; + } + + const uint32_t targetLevel = AllocSizeToLevel(allocSize); + for(uint32_t level = targetLevel + 1; level--; ) + { + for(Node* freeNode = m_FreeList[level].front; + freeNode != VMA_NULL; + freeNode = freeNode->free.next) + { + if(freeNode->offset % allocAlignment == 0) + { + pAllocationRequest->type = VmaAllocationRequestType::Normal; + pAllocationRequest->offset = freeNode->offset; + pAllocationRequest->sumFreeSize = LevelToNodeSize(level); + pAllocationRequest->sumItemSize = 0; + pAllocationRequest->itemsToMakeLostCount = 0; + pAllocationRequest->customData = (void*)(uintptr_t)level; + return true; + } + } + } + + return false; +} + +bool VmaBlockMetadata_Buddy::MakeRequestedAllocationsLost( + uint32_t currentFrameIndex, + uint32_t frameInUseCount, + VmaAllocationRequest* pAllocationRequest) +{ + /* + Lost allocations are not supported in buddy allocator at the moment. + Support might be added in the future. + */ + return pAllocationRequest->itemsToMakeLostCount == 0; +} + +uint32_t VmaBlockMetadata_Buddy::MakeAllocationsLost(uint32_t currentFrameIndex, uint32_t frameInUseCount) +{ + /* + Lost allocations are not supported in buddy allocator at the moment. + Support might be added in the future. + */ + return 0; +} + +void VmaBlockMetadata_Buddy::Alloc( + const VmaAllocationRequest& request, + VmaSuballocationType type, + VkDeviceSize allocSize, + VmaAllocation hAllocation) +{ + VMA_ASSERT(request.type == VmaAllocationRequestType::Normal); + + const uint32_t targetLevel = AllocSizeToLevel(allocSize); + uint32_t currLevel = (uint32_t)(uintptr_t)request.customData; + + Node* currNode = m_FreeList[currLevel].front; + VMA_ASSERT(currNode != VMA_NULL && currNode->type == Node::TYPE_FREE); + while(currNode->offset != request.offset) + { + currNode = currNode->free.next; + VMA_ASSERT(currNode != VMA_NULL && currNode->type == Node::TYPE_FREE); + } + + // Go down, splitting free nodes. + while(currLevel < targetLevel) + { + // currNode is already first free node at currLevel. + // Remove it from list of free nodes at this currLevel. + RemoveFromFreeList(currLevel, currNode); + + const uint32_t childrenLevel = currLevel + 1; + + // Create two free sub-nodes. + Node* leftChild = vma_new(GetAllocationCallbacks(), Node)(); + Node* rightChild = vma_new(GetAllocationCallbacks(), Node)(); + + leftChild->offset = currNode->offset; + leftChild->type = Node::TYPE_FREE; + leftChild->parent = currNode; + leftChild->buddy = rightChild; + + rightChild->offset = currNode->offset + LevelToNodeSize(childrenLevel); + rightChild->type = Node::TYPE_FREE; + rightChild->parent = currNode; + rightChild->buddy = leftChild; + + // Convert current currNode to split type. + currNode->type = Node::TYPE_SPLIT; + currNode->split.leftChild = leftChild; + + // Add child nodes to free list. Order is important! + AddToFreeListFront(childrenLevel, rightChild); + AddToFreeListFront(childrenLevel, leftChild); + + ++m_FreeCount; + //m_SumFreeSize -= LevelToNodeSize(currLevel) % 2; // Useful only when level node sizes can be non power of 2. + ++currLevel; + currNode = m_FreeList[currLevel].front; + + /* + We can be sure that currNode, as left child of node previously split, + also fullfills the alignment requirement. + */ + } + + // Remove from free list. + VMA_ASSERT(currLevel == targetLevel && + currNode != VMA_NULL && + currNode->type == Node::TYPE_FREE); + RemoveFromFreeList(currLevel, currNode); + + // Convert to allocation node. + currNode->type = Node::TYPE_ALLOCATION; + currNode->allocation.alloc = hAllocation; + + ++m_AllocationCount; + --m_FreeCount; + m_SumFreeSize -= allocSize; +} + +void VmaBlockMetadata_Buddy::DeleteNode(Node* node) +{ + if(node->type == Node::TYPE_SPLIT) + { + DeleteNode(node->split.leftChild->buddy); + DeleteNode(node->split.leftChild); + } + + vma_delete(GetAllocationCallbacks(), node); +} + +bool VmaBlockMetadata_Buddy::ValidateNode(ValidationContext& ctx, const Node* parent, const Node* curr, uint32_t level, VkDeviceSize levelNodeSize) const +{ + VMA_VALIDATE(level < m_LevelCount); + VMA_VALIDATE(curr->parent == parent); + VMA_VALIDATE((curr->buddy == VMA_NULL) == (parent == VMA_NULL)); + VMA_VALIDATE(curr->buddy == VMA_NULL || curr->buddy->buddy == curr); + switch(curr->type) + { + case Node::TYPE_FREE: + // curr->free.prev, next are validated separately. + ctx.calculatedSumFreeSize += levelNodeSize; + ++ctx.calculatedFreeCount; + break; + case Node::TYPE_ALLOCATION: + ++ctx.calculatedAllocationCount; + ctx.calculatedSumFreeSize += levelNodeSize - curr->allocation.alloc->GetSize(); + VMA_VALIDATE(curr->allocation.alloc != VK_NULL_HANDLE); + break; + case Node::TYPE_SPLIT: + { + const uint32_t childrenLevel = level + 1; + const VkDeviceSize childrenLevelNodeSize = levelNodeSize / 2; + const Node* const leftChild = curr->split.leftChild; + VMA_VALIDATE(leftChild != VMA_NULL); + VMA_VALIDATE(leftChild->offset == curr->offset); + if(!ValidateNode(ctx, curr, leftChild, childrenLevel, childrenLevelNodeSize)) + { + VMA_VALIDATE(false && "ValidateNode for left child failed."); + } + const Node* const rightChild = leftChild->buddy; + VMA_VALIDATE(rightChild->offset == curr->offset + childrenLevelNodeSize); + if(!ValidateNode(ctx, curr, rightChild, childrenLevel, childrenLevelNodeSize)) + { + VMA_VALIDATE(false && "ValidateNode for right child failed."); + } + } + break; + default: + return false; + } + + return true; +} + +uint32_t VmaBlockMetadata_Buddy::AllocSizeToLevel(VkDeviceSize allocSize) const +{ + // I know this could be optimized somehow e.g. by using std::log2p1 from C++20. + uint32_t level = 0; + VkDeviceSize currLevelNodeSize = m_UsableSize; + VkDeviceSize nextLevelNodeSize = currLevelNodeSize >> 1; + while(allocSize <= nextLevelNodeSize && level + 1 < m_LevelCount) + { + ++level; + currLevelNodeSize = nextLevelNodeSize; + nextLevelNodeSize = currLevelNodeSize >> 1; + } + return level; +} + +void VmaBlockMetadata_Buddy::FreeAtOffset(VmaAllocation alloc, VkDeviceSize offset) +{ + // Find node and level. + Node* node = m_Root; + VkDeviceSize nodeOffset = 0; + uint32_t level = 0; + VkDeviceSize levelNodeSize = LevelToNodeSize(0); + while(node->type == Node::TYPE_SPLIT) + { + const VkDeviceSize nextLevelSize = levelNodeSize >> 1; + if(offset < nodeOffset + nextLevelSize) + { + node = node->split.leftChild; + } + else + { + node = node->split.leftChild->buddy; + nodeOffset += nextLevelSize; + } + ++level; + levelNodeSize = nextLevelSize; + } + + VMA_ASSERT(node != VMA_NULL && node->type == Node::TYPE_ALLOCATION); + VMA_ASSERT(alloc == VK_NULL_HANDLE || node->allocation.alloc == alloc); + + ++m_FreeCount; + --m_AllocationCount; + m_SumFreeSize += alloc->GetSize(); + + node->type = Node::TYPE_FREE; + + // Join free nodes if possible. + while(level > 0 && node->buddy->type == Node::TYPE_FREE) + { + RemoveFromFreeList(level, node->buddy); + Node* const parent = node->parent; + + vma_delete(GetAllocationCallbacks(), node->buddy); + vma_delete(GetAllocationCallbacks(), node); + parent->type = Node::TYPE_FREE; + + node = parent; + --level; + //m_SumFreeSize += LevelToNodeSize(level) % 2; // Useful only when level node sizes can be non power of 2. + --m_FreeCount; + } + + AddToFreeListFront(level, node); +} + +void VmaBlockMetadata_Buddy::CalcAllocationStatInfoNode(VmaStatInfo& outInfo, const Node* node, VkDeviceSize levelNodeSize) const +{ + switch(node->type) + { + case Node::TYPE_FREE: + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += levelNodeSize; + outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, levelNodeSize); + outInfo.unusedRangeSizeMin = VMA_MAX(outInfo.unusedRangeSizeMin, levelNodeSize); + break; + case Node::TYPE_ALLOCATION: + { + const VkDeviceSize allocSize = node->allocation.alloc->GetSize(); + ++outInfo.allocationCount; + outInfo.usedBytes += allocSize; + outInfo.allocationSizeMax = VMA_MAX(outInfo.allocationSizeMax, allocSize); + outInfo.allocationSizeMin = VMA_MAX(outInfo.allocationSizeMin, allocSize); + + const VkDeviceSize unusedRangeSize = levelNodeSize - allocSize; + if(unusedRangeSize > 0) + { + ++outInfo.unusedRangeCount; + outInfo.unusedBytes += unusedRangeSize; + outInfo.unusedRangeSizeMax = VMA_MAX(outInfo.unusedRangeSizeMax, unusedRangeSize); + outInfo.unusedRangeSizeMin = VMA_MAX(outInfo.unusedRangeSizeMin, unusedRangeSize); + } + } + break; + case Node::TYPE_SPLIT: + { + const VkDeviceSize childrenNodeSize = levelNodeSize / 2; + const Node* const leftChild = node->split.leftChild; + CalcAllocationStatInfoNode(outInfo, leftChild, childrenNodeSize); + const Node* const rightChild = leftChild->buddy; + CalcAllocationStatInfoNode(outInfo, rightChild, childrenNodeSize); + } + break; + default: + VMA_ASSERT(0); + } +} + +void VmaBlockMetadata_Buddy::AddToFreeListFront(uint32_t level, Node* node) +{ + VMA_ASSERT(node->type == Node::TYPE_FREE); + + // List is empty. + Node* const frontNode = m_FreeList[level].front; + if(frontNode == VMA_NULL) + { + VMA_ASSERT(m_FreeList[level].back == VMA_NULL); + node->free.prev = node->free.next = VMA_NULL; + m_FreeList[level].front = m_FreeList[level].back = node; + } + else + { + VMA_ASSERT(frontNode->free.prev == VMA_NULL); + node->free.prev = VMA_NULL; + node->free.next = frontNode; + frontNode->free.prev = node; + m_FreeList[level].front = node; + } +} + +void VmaBlockMetadata_Buddy::RemoveFromFreeList(uint32_t level, Node* node) +{ + VMA_ASSERT(m_FreeList[level].front != VMA_NULL); + + // It is at the front. + if(node->free.prev == VMA_NULL) + { + VMA_ASSERT(m_FreeList[level].front == node); + m_FreeList[level].front = node->free.next; + } + else + { + Node* const prevFreeNode = node->free.prev; + VMA_ASSERT(prevFreeNode->free.next == node); + prevFreeNode->free.next = node->free.next; + } + + // It is at the back. + if(node->free.next == VMA_NULL) + { + VMA_ASSERT(m_FreeList[level].back == node); + m_FreeList[level].back = node->free.prev; + } + else + { + Node* const nextFreeNode = node->free.next; + VMA_ASSERT(nextFreeNode->free.prev == node); + nextFreeNode->free.prev = node->free.prev; + } +} + +#if VMA_STATS_STRING_ENABLED +void VmaBlockMetadata_Buddy::PrintDetailedMapNode(class VmaJsonWriter& json, const Node* node, VkDeviceSize levelNodeSize) const +{ + switch(node->type) + { + case Node::TYPE_FREE: + PrintDetailedMap_UnusedRange(json, node->offset, levelNodeSize); + break; + case Node::TYPE_ALLOCATION: + { + PrintDetailedMap_Allocation(json, node->offset, node->allocation.alloc); + const VkDeviceSize allocSize = node->allocation.alloc->GetSize(); + if(allocSize < levelNodeSize) + { + PrintDetailedMap_UnusedRange(json, node->offset + allocSize, levelNodeSize - allocSize); + } + } + break; + case Node::TYPE_SPLIT: + { + const VkDeviceSize childrenNodeSize = levelNodeSize / 2; + const Node* const leftChild = node->split.leftChild; + PrintDetailedMapNode(json, leftChild, childrenNodeSize); + const Node* const rightChild = leftChild->buddy; + PrintDetailedMapNode(json, rightChild, childrenNodeSize); + } + break; + default: + VMA_ASSERT(0); + } +} +#endif // #if VMA_STATS_STRING_ENABLED + + +//////////////////////////////////////////////////////////////////////////////// +// class VmaDeviceMemoryBlock + +VmaDeviceMemoryBlock::VmaDeviceMemoryBlock(VmaAllocator hAllocator) : + m_pMetadata(VMA_NULL), + m_MemoryTypeIndex(UINT32_MAX), + m_Id(0), + m_hMemory(VK_NULL_HANDLE), + m_MapCount(0), + m_pMappedData(VMA_NULL) +{ +} + +void VmaDeviceMemoryBlock::Init( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t newMemoryTypeIndex, + VkDeviceMemory newMemory, + VkDeviceSize newSize, + uint32_t id, + uint32_t algorithm) +{ + VMA_ASSERT(m_hMemory == VK_NULL_HANDLE); + + m_hParentPool = hParentPool; + m_MemoryTypeIndex = newMemoryTypeIndex; + m_Id = id; + m_hMemory = newMemory; + + switch(algorithm) + { + case VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT: + m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Linear)(hAllocator); + break; + case VMA_POOL_CREATE_BUDDY_ALGORITHM_BIT: + m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Buddy)(hAllocator); + break; + default: + VMA_ASSERT(0); + // Fall-through. + case 0: + m_pMetadata = vma_new(hAllocator, VmaBlockMetadata_Generic)(hAllocator); + } + m_pMetadata->Init(newSize); +} + +void VmaDeviceMemoryBlock::Destroy(VmaAllocator allocator) +{ + // This is the most important assert in the entire library. + // Hitting it means you have some memory leak - unreleased VmaAllocation objects. + VMA_ASSERT(m_pMetadata->IsEmpty() && "Some allocations were not freed before destruction of this memory block!"); + + VMA_ASSERT(m_hMemory != VK_NULL_HANDLE); + allocator->FreeVulkanMemory(m_MemoryTypeIndex, m_pMetadata->GetSize(), m_hMemory); + m_hMemory = VK_NULL_HANDLE; + + vma_delete(allocator, m_pMetadata); + m_pMetadata = VMA_NULL; +} + +bool VmaDeviceMemoryBlock::Validate() const +{ + VMA_VALIDATE((m_hMemory != VK_NULL_HANDLE) && + (m_pMetadata->GetSize() != 0)); + + return m_pMetadata->Validate(); +} + +VkResult VmaDeviceMemoryBlock::CheckCorruption(VmaAllocator hAllocator) +{ + void* pData = nullptr; + VkResult res = Map(hAllocator, 1, &pData); + if(res != VK_SUCCESS) + { + return res; + } + + res = m_pMetadata->CheckCorruption(pData); + + Unmap(hAllocator, 1); + + return res; +} + +VkResult VmaDeviceMemoryBlock::Map(VmaAllocator hAllocator, uint32_t count, void** ppData) +{ + if(count == 0) + { + return VK_SUCCESS; + } + + VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); + if(m_MapCount != 0) + { + m_MapCount += count; + VMA_ASSERT(m_pMappedData != VMA_NULL); + if(ppData != VMA_NULL) + { + *ppData = m_pMappedData; + } + return VK_SUCCESS; + } + else + { + VkResult result = (*hAllocator->GetVulkanFunctions().vkMapMemory)( + hAllocator->m_hDevice, + m_hMemory, + 0, // offset + VK_WHOLE_SIZE, + 0, // flags + &m_pMappedData); + if(result == VK_SUCCESS) + { + if(ppData != VMA_NULL) + { + *ppData = m_pMappedData; + } + m_MapCount = count; + } + return result; + } +} + +void VmaDeviceMemoryBlock::Unmap(VmaAllocator hAllocator, uint32_t count) +{ + if(count == 0) + { + return; + } + + VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); + if(m_MapCount >= count) + { + m_MapCount -= count; + if(m_MapCount == 0) + { + m_pMappedData = VMA_NULL; + (*hAllocator->GetVulkanFunctions().vkUnmapMemory)(hAllocator->m_hDevice, m_hMemory); + } + } + else + { + VMA_ASSERT(0 && "VkDeviceMemory block is being unmapped while it was not previously mapped."); + } +} + +VkResult VmaDeviceMemoryBlock::WriteMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize) +{ + VMA_ASSERT(VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_MARGIN % 4 == 0 && VMA_DEBUG_DETECT_CORRUPTION); + VMA_ASSERT(allocOffset >= VMA_DEBUG_MARGIN); + + void* pData; + VkResult res = Map(hAllocator, 1, &pData); + if(res != VK_SUCCESS) + { + return res; + } + + VmaWriteMagicValue(pData, allocOffset - VMA_DEBUG_MARGIN); + VmaWriteMagicValue(pData, allocOffset + allocSize); + + Unmap(hAllocator, 1); + + return VK_SUCCESS; +} + +VkResult VmaDeviceMemoryBlock::ValidateMagicValueAroundAllocation(VmaAllocator hAllocator, VkDeviceSize allocOffset, VkDeviceSize allocSize) +{ + VMA_ASSERT(VMA_DEBUG_MARGIN > 0 && VMA_DEBUG_MARGIN % 4 == 0 && VMA_DEBUG_DETECT_CORRUPTION); + VMA_ASSERT(allocOffset >= VMA_DEBUG_MARGIN); + + void* pData; + VkResult res = Map(hAllocator, 1, &pData); + if(res != VK_SUCCESS) + { + return res; + } + + if(!VmaValidateMagicValue(pData, allocOffset - VMA_DEBUG_MARGIN)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED BEFORE FREED ALLOCATION!"); + } + else if(!VmaValidateMagicValue(pData, allocOffset + allocSize)) + { + VMA_ASSERT(0 && "MEMORY CORRUPTION DETECTED AFTER FREED ALLOCATION!"); + } + + Unmap(hAllocator, 1); + + return VK_SUCCESS; +} + +VkResult VmaDeviceMemoryBlock::BindBufferMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext) +{ + VMA_ASSERT(hAllocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK && + hAllocation->GetBlock() == this); + VMA_ASSERT(allocationLocalOffset < hAllocation->GetSize() && + "Invalid allocationLocalOffset. Did you forget that this offset is relative to the beginning of the allocation, not the whole memory block?"); + const VkDeviceSize memoryOffset = hAllocation->GetOffset() + allocationLocalOffset; + // This lock is important so that we don't call vkBind... and/or vkMap... simultaneously on the same VkDeviceMemory from multiple threads. + VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); + return hAllocator->BindVulkanBuffer(m_hMemory, memoryOffset, hBuffer, pNext); +} + +VkResult VmaDeviceMemoryBlock::BindImageMemory( + const VmaAllocator hAllocator, + const VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext) +{ + VMA_ASSERT(hAllocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK && + hAllocation->GetBlock() == this); + VMA_ASSERT(allocationLocalOffset < hAllocation->GetSize() && + "Invalid allocationLocalOffset. Did you forget that this offset is relative to the beginning of the allocation, not the whole memory block?"); + const VkDeviceSize memoryOffset = hAllocation->GetOffset() + allocationLocalOffset; + // This lock is important so that we don't call vkBind... and/or vkMap... simultaneously on the same VkDeviceMemory from multiple threads. + VmaMutexLock lock(m_Mutex, hAllocator->m_UseMutex); + return hAllocator->BindVulkanImage(m_hMemory, memoryOffset, hImage, pNext); +} + +static void InitStatInfo(VmaStatInfo& outInfo) +{ + memset(&outInfo, 0, sizeof(outInfo)); + outInfo.allocationSizeMin = UINT64_MAX; + outInfo.unusedRangeSizeMin = UINT64_MAX; +} + +// Adds statistics srcInfo into inoutInfo, like: inoutInfo += srcInfo. +static void VmaAddStatInfo(VmaStatInfo& inoutInfo, const VmaStatInfo& srcInfo) +{ + inoutInfo.blockCount += srcInfo.blockCount; + inoutInfo.allocationCount += srcInfo.allocationCount; + inoutInfo.unusedRangeCount += srcInfo.unusedRangeCount; + inoutInfo.usedBytes += srcInfo.usedBytes; + inoutInfo.unusedBytes += srcInfo.unusedBytes; + inoutInfo.allocationSizeMin = VMA_MIN(inoutInfo.allocationSizeMin, srcInfo.allocationSizeMin); + inoutInfo.allocationSizeMax = VMA_MAX(inoutInfo.allocationSizeMax, srcInfo.allocationSizeMax); + inoutInfo.unusedRangeSizeMin = VMA_MIN(inoutInfo.unusedRangeSizeMin, srcInfo.unusedRangeSizeMin); + inoutInfo.unusedRangeSizeMax = VMA_MAX(inoutInfo.unusedRangeSizeMax, srcInfo.unusedRangeSizeMax); +} + +static void VmaPostprocessCalcStatInfo(VmaStatInfo& inoutInfo) +{ + inoutInfo.allocationSizeAvg = (inoutInfo.allocationCount > 0) ? + VmaRoundDiv(inoutInfo.usedBytes, inoutInfo.allocationCount) : 0; + inoutInfo.unusedRangeSizeAvg = (inoutInfo.unusedRangeCount > 0) ? + VmaRoundDiv(inoutInfo.unusedBytes, inoutInfo.unusedRangeCount) : 0; +} + +VmaPool_T::VmaPool_T( + VmaAllocator hAllocator, + const VmaPoolCreateInfo& createInfo, + VkDeviceSize preferredBlockSize) : + m_BlockVector( + hAllocator, + this, // hParentPool + createInfo.memoryTypeIndex, + createInfo.blockSize != 0 ? createInfo.blockSize : preferredBlockSize, + createInfo.minBlockCount, + createInfo.maxBlockCount, + (createInfo.flags & VMA_POOL_CREATE_IGNORE_BUFFER_IMAGE_GRANULARITY_BIT) != 0 ? 1 : hAllocator->GetBufferImageGranularity(), + createInfo.frameInUseCount, + createInfo.blockSize != 0, // explicitBlockSize + createInfo.flags & VMA_POOL_CREATE_ALGORITHM_MASK), // algorithm + m_Id(0), + m_Name(VMA_NULL) +{ +} + +VmaPool_T::~VmaPool_T() +{ +} + +void VmaPool_T::SetName(const char* pName) +{ + const VkAllocationCallbacks* allocs = m_BlockVector.GetAllocator()->GetAllocationCallbacks(); + VmaFreeString(allocs, m_Name); + + if(pName != VMA_NULL) + { + m_Name = VmaCreateStringCopy(allocs, pName); + } + else + { + m_Name = VMA_NULL; + } +} + +#if VMA_STATS_STRING_ENABLED + +#endif // #if VMA_STATS_STRING_ENABLED + +VmaBlockVector::VmaBlockVector( + VmaAllocator hAllocator, + VmaPool hParentPool, + uint32_t memoryTypeIndex, + VkDeviceSize preferredBlockSize, + size_t minBlockCount, + size_t maxBlockCount, + VkDeviceSize bufferImageGranularity, + uint32_t frameInUseCount, + bool explicitBlockSize, + uint32_t algorithm) : + m_hAllocator(hAllocator), + m_hParentPool(hParentPool), + m_MemoryTypeIndex(memoryTypeIndex), + m_PreferredBlockSize(preferredBlockSize), + m_MinBlockCount(minBlockCount), + m_MaxBlockCount(maxBlockCount), + m_BufferImageGranularity(bufferImageGranularity), + m_FrameInUseCount(frameInUseCount), + m_ExplicitBlockSize(explicitBlockSize), + m_Algorithm(algorithm), + m_HasEmptyBlock(false), + m_Blocks(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + m_NextBlockId(0) +{ +} + +VmaBlockVector::~VmaBlockVector() +{ + for(size_t i = m_Blocks.size(); i--; ) + { + m_Blocks[i]->Destroy(m_hAllocator); + vma_delete(m_hAllocator, m_Blocks[i]); + } +} + +VkResult VmaBlockVector::CreateMinBlocks() +{ + for(size_t i = 0; i < m_MinBlockCount; ++i) + { + VkResult res = CreateBlock(m_PreferredBlockSize, VMA_NULL); + if(res != VK_SUCCESS) + { + return res; + } + } + return VK_SUCCESS; +} + +void VmaBlockVector::GetPoolStats(VmaPoolStats* pStats) +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + + const size_t blockCount = m_Blocks.size(); + + pStats->size = 0; + pStats->unusedSize = 0; + pStats->allocationCount = 0; + pStats->unusedRangeCount = 0; + pStats->unusedRangeSizeMax = 0; + pStats->blockCount = blockCount; + + for(uint32_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) + { + const VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + VMA_HEAVY_ASSERT(pBlock->Validate()); + pBlock->m_pMetadata->AddPoolStats(*pStats); + } +} + +bool VmaBlockVector::IsEmpty() +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + return m_Blocks.empty(); +} + +bool VmaBlockVector::IsCorruptionDetectionEnabled() const +{ + const uint32_t requiredMemFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + return (VMA_DEBUG_DETECT_CORRUPTION != 0) && + (VMA_DEBUG_MARGIN > 0) && + (m_Algorithm == 0 || m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) && + (m_hAllocator->m_MemProps.memoryTypes[m_MemoryTypeIndex].propertyFlags & requiredMemFlags) == requiredMemFlags; +} + +static const uint32_t VMA_ALLOCATION_TRY_COUNT = 32; + +VkResult VmaBlockVector::Allocate( + uint32_t currentFrameIndex, + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations) +{ + size_t allocIndex; + VkResult res = VK_SUCCESS; + + if(IsCorruptionDetectionEnabled()) + { + size = VmaAlignUp(size, sizeof(VMA_CORRUPTION_DETECTION_MAGIC_VALUE)); + alignment = VmaAlignUp(alignment, sizeof(VMA_CORRUPTION_DETECTION_MAGIC_VALUE)); + } + + { + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + res = AllocatePage( + currentFrameIndex, + size, + alignment, + createInfo, + suballocType, + pAllocations + allocIndex); + if(res != VK_SUCCESS) + { + break; + } + } + } + + if(res != VK_SUCCESS) + { + // Free all already created allocations. + while(allocIndex--) + { + Free(pAllocations[allocIndex]); + } + memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); + } + + return res; +} + +VkResult VmaBlockVector::AllocatePage( + uint32_t currentFrameIndex, + VkDeviceSize size, + VkDeviceSize alignment, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + VmaAllocation* pAllocation) +{ + const bool isUpperAddress = (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0; + bool canMakeOtherLost = (createInfo.flags & VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT) != 0; + const bool mapped = (createInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0; + const bool isUserDataString = (createInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0; + + VkDeviceSize freeMemory; + { + const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); + VmaBudget heapBudget = {}; + m_hAllocator->GetBudget(&heapBudget, heapIndex, 1); + freeMemory = (heapBudget.usage < heapBudget.budget) ? (heapBudget.budget - heapBudget.usage) : 0; + } + + const bool canFallbackToDedicated = !IsCustomPool(); + const bool canCreateNewBlock = + ((createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0) && + (m_Blocks.size() < m_MaxBlockCount) && + (freeMemory >= size || !canFallbackToDedicated); + uint32_t strategy = createInfo.flags & VMA_ALLOCATION_CREATE_STRATEGY_MASK; + + // If linearAlgorithm is used, canMakeOtherLost is available only when used as ring buffer. + // Which in turn is available only when maxBlockCount = 1. + if(m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT && m_MaxBlockCount > 1) + { + canMakeOtherLost = false; + } + + // Upper address can only be used with linear allocator and within single memory block. + if(isUpperAddress && + (m_Algorithm != VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT || m_MaxBlockCount > 1)) + { + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + // Validate strategy. + switch(strategy) + { + case 0: + strategy = VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT; + break; + case VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT: + case VMA_ALLOCATION_CREATE_STRATEGY_WORST_FIT_BIT: + case VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT: + break; + default: + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + // Early reject: requested allocation size is larger that maximum block size for this block vector. + if(size + 2 * VMA_DEBUG_MARGIN > m_PreferredBlockSize) + { + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + + /* + Under certain condition, this whole section can be skipped for optimization, so + we move on directly to trying to allocate with canMakeOtherLost. That's the case + e.g. for custom pools with linear algorithm. + */ + if(!canMakeOtherLost || canCreateNewBlock) + { + // 1. Search existing allocations. Try to allocate without making other allocations lost. + VmaAllocationCreateFlags allocFlagsCopy = createInfo.flags; + allocFlagsCopy &= ~VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT; + + if(m_Algorithm == VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) + { + // Use only last block. + if(!m_Blocks.empty()) + { + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks.back(); + VMA_ASSERT(pCurrBlock); + VkResult res = AllocateFromBlock( + pCurrBlock, + currentFrameIndex, + size, + alignment, + allocFlagsCopy, + createInfo.pUserData, + suballocType, + strategy, + pAllocation); + if(res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Returned from last block #%u", pCurrBlock->GetId()); + return VK_SUCCESS; + } + } + } + else + { + if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT) + { + // Forward order in m_Blocks - prefer blocks with smallest amount of free space. + for(size_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex ) + { + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + VkResult res = AllocateFromBlock( + pCurrBlock, + currentFrameIndex, + size, + alignment, + allocFlagsCopy, + createInfo.pUserData, + suballocType, + strategy, + pAllocation); + if(res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); + return VK_SUCCESS; + } + } + } + else // WORST_FIT, FIRST_FIT + { + // Backward order in m_Blocks - prefer blocks with largest amount of free space. + for(size_t blockIndex = m_Blocks.size(); blockIndex--; ) + { + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + VkResult res = AllocateFromBlock( + pCurrBlock, + currentFrameIndex, + size, + alignment, + allocFlagsCopy, + createInfo.pUserData, + suballocType, + strategy, + pAllocation); + if(res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Returned from existing block #%u", pCurrBlock->GetId()); + return VK_SUCCESS; + } + } + } + } + + // 2. Try to create new block. + if(canCreateNewBlock) + { + // Calculate optimal size for new block. + VkDeviceSize newBlockSize = m_PreferredBlockSize; + uint32_t newBlockSizeShift = 0; + const uint32_t NEW_BLOCK_SIZE_SHIFT_MAX = 3; + + if(!m_ExplicitBlockSize) + { + // Allocate 1/8, 1/4, 1/2 as first blocks. + const VkDeviceSize maxExistingBlockSize = CalcMaxBlockSize(); + for(uint32_t i = 0; i < NEW_BLOCK_SIZE_SHIFT_MAX; ++i) + { + const VkDeviceSize smallerNewBlockSize = newBlockSize / 2; + if(smallerNewBlockSize > maxExistingBlockSize && smallerNewBlockSize >= size * 2) + { + newBlockSize = smallerNewBlockSize; + ++newBlockSizeShift; + } + else + { + break; + } + } + } + + size_t newBlockIndex = 0; + VkResult res = (newBlockSize <= freeMemory || !canFallbackToDedicated) ? + CreateBlock(newBlockSize, &newBlockIndex) : VK_ERROR_OUT_OF_DEVICE_MEMORY; + // Allocation of this size failed? Try 1/2, 1/4, 1/8 of m_PreferredBlockSize. + if(!m_ExplicitBlockSize) + { + while(res < 0 && newBlockSizeShift < NEW_BLOCK_SIZE_SHIFT_MAX) + { + const VkDeviceSize smallerNewBlockSize = newBlockSize / 2; + if(smallerNewBlockSize >= size) + { + newBlockSize = smallerNewBlockSize; + ++newBlockSizeShift; + res = (newBlockSize <= freeMemory || !canFallbackToDedicated) ? + CreateBlock(newBlockSize, &newBlockIndex) : VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + else + { + break; + } + } + } + + if(res == VK_SUCCESS) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[newBlockIndex]; + VMA_ASSERT(pBlock->m_pMetadata->GetSize() >= size); + + res = AllocateFromBlock( + pBlock, + currentFrameIndex, + size, + alignment, + allocFlagsCopy, + createInfo.pUserData, + suballocType, + strategy, + pAllocation); + if(res == VK_SUCCESS) + { + VMA_DEBUG_LOG(" Created new block #%u Size=%llu", pBlock->GetId(), newBlockSize); + return VK_SUCCESS; + } + else + { + // Allocation from new block failed, possibly due to VMA_DEBUG_MARGIN or alignment. + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + } + } + } + + // 3. Try to allocate from existing blocks with making other allocations lost. + if(canMakeOtherLost) + { + uint32_t tryIndex = 0; + for(; tryIndex < VMA_ALLOCATION_TRY_COUNT; ++tryIndex) + { + VmaDeviceMemoryBlock* pBestRequestBlock = VMA_NULL; + VmaAllocationRequest bestRequest = {}; + VkDeviceSize bestRequestCost = VK_WHOLE_SIZE; + + // 1. Search existing allocations. + if(strategy == VMA_ALLOCATION_CREATE_STRATEGY_BEST_FIT_BIT) + { + // Forward order in m_Blocks - prefer blocks with smallest amount of free space. + for(size_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex ) + { + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + VmaAllocationRequest currRequest = {}; + if(pCurrBlock->m_pMetadata->CreateAllocationRequest( + currentFrameIndex, + m_FrameInUseCount, + m_BufferImageGranularity, + size, + alignment, + (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0, + suballocType, + canMakeOtherLost, + strategy, + &currRequest)) + { + const VkDeviceSize currRequestCost = currRequest.CalcCost(); + if(pBestRequestBlock == VMA_NULL || + currRequestCost < bestRequestCost) + { + pBestRequestBlock = pCurrBlock; + bestRequest = currRequest; + bestRequestCost = currRequestCost; + + if(bestRequestCost == 0) + { + break; + } + } + } + } + } + else // WORST_FIT, FIRST_FIT + { + // Backward order in m_Blocks - prefer blocks with largest amount of free space. + for(size_t blockIndex = m_Blocks.size(); blockIndex--; ) + { + VmaDeviceMemoryBlock* const pCurrBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pCurrBlock); + VmaAllocationRequest currRequest = {}; + if(pCurrBlock->m_pMetadata->CreateAllocationRequest( + currentFrameIndex, + m_FrameInUseCount, + m_BufferImageGranularity, + size, + alignment, + (createInfo.flags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0, + suballocType, + canMakeOtherLost, + strategy, + &currRequest)) + { + const VkDeviceSize currRequestCost = currRequest.CalcCost(); + if(pBestRequestBlock == VMA_NULL || + currRequestCost < bestRequestCost || + strategy == VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT) + { + pBestRequestBlock = pCurrBlock; + bestRequest = currRequest; + bestRequestCost = currRequestCost; + + if(bestRequestCost == 0 || + strategy == VMA_ALLOCATION_CREATE_STRATEGY_FIRST_FIT_BIT) + { + break; + } + } + } + } + } + + if(pBestRequestBlock != VMA_NULL) + { + if(mapped) + { + VkResult res = pBestRequestBlock->Map(m_hAllocator, 1, VMA_NULL); + if(res != VK_SUCCESS) + { + return res; + } + } + + if(pBestRequestBlock->m_pMetadata->MakeRequestedAllocationsLost( + currentFrameIndex, + m_FrameInUseCount, + &bestRequest)) + { + // Allocate from this pBlock. + *pAllocation = m_hAllocator->m_AllocationObjectAllocator.Allocate(currentFrameIndex, isUserDataString); + pBestRequestBlock->m_pMetadata->Alloc(bestRequest, suballocType, size, *pAllocation); + UpdateHasEmptyBlock(); + (*pAllocation)->InitBlockAllocation( + pBestRequestBlock, + bestRequest.offset, + alignment, + size, + m_MemoryTypeIndex, + suballocType, + mapped, + (createInfo.flags & VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT) != 0); + VMA_HEAVY_ASSERT(pBestRequestBlock->Validate()); + VMA_DEBUG_LOG(" Returned from existing block"); + (*pAllocation)->SetUserData(m_hAllocator, createInfo.pUserData); + m_hAllocator->m_Budget.AddAllocation(m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex), size); + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + { + m_hAllocator->FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); + } + if(IsCorruptionDetectionEnabled()) + { + VkResult res = pBestRequestBlock->WriteMagicValueAroundAllocation(m_hAllocator, bestRequest.offset, size); + VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to write magic value."); + } + return VK_SUCCESS; + } + // else: Some allocations must have been touched while we are here. Next try. + } + else + { + // Could not find place in any of the blocks - break outer loop. + break; + } + } + /* Maximum number of tries exceeded - a very unlike event when many other + threads are simultaneously touching allocations making it impossible to make + lost at the same time as we try to allocate. */ + if(tryIndex == VMA_ALLOCATION_TRY_COUNT) + { + return VK_ERROR_TOO_MANY_OBJECTS; + } + } + + return VK_ERROR_OUT_OF_DEVICE_MEMORY; +} + +void VmaBlockVector::Free( + const VmaAllocation hAllocation) +{ + VmaDeviceMemoryBlock* pBlockToDelete = VMA_NULL; + + bool budgetExceeded = false; + { + const uint32_t heapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex); + VmaBudget heapBudget = {}; + m_hAllocator->GetBudget(&heapBudget, heapIndex, 1); + budgetExceeded = heapBudget.usage >= heapBudget.budget; + } + + // Scope for lock. + { + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + + VmaDeviceMemoryBlock* pBlock = hAllocation->GetBlock(); + + if(IsCorruptionDetectionEnabled()) + { + VkResult res = pBlock->ValidateMagicValueAroundAllocation(m_hAllocator, hAllocation->GetOffset(), hAllocation->GetSize()); + VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to validate magic value."); + } + + if(hAllocation->IsPersistentMap()) + { + pBlock->Unmap(m_hAllocator, 1); + } + + pBlock->m_pMetadata->Free(hAllocation); + VMA_HEAVY_ASSERT(pBlock->Validate()); + + VMA_DEBUG_LOG(" Freed from MemoryTypeIndex=%u", m_MemoryTypeIndex); + + const bool canDeleteBlock = m_Blocks.size() > m_MinBlockCount; + // pBlock became empty after this deallocation. + if(pBlock->m_pMetadata->IsEmpty()) + { + // Already has empty block. We don't want to have two, so delete this one. + if((m_HasEmptyBlock || budgetExceeded) && canDeleteBlock) + { + pBlockToDelete = pBlock; + Remove(pBlock); + } + // else: We now have an empty block - leave it. + } + // pBlock didn't become empty, but we have another empty block - find and free that one. + // (This is optional, heuristics.) + else if(m_HasEmptyBlock && canDeleteBlock) + { + VmaDeviceMemoryBlock* pLastBlock = m_Blocks.back(); + if(pLastBlock->m_pMetadata->IsEmpty()) + { + pBlockToDelete = pLastBlock; + m_Blocks.pop_back(); + } + } + + UpdateHasEmptyBlock(); + IncrementallySortBlocks(); + } + + // Destruction of a free block. Deferred until this point, outside of mutex + // lock, for performance reason. + if(pBlockToDelete != VMA_NULL) + { + VMA_DEBUG_LOG(" Deleted empty block"); + pBlockToDelete->Destroy(m_hAllocator); + vma_delete(m_hAllocator, pBlockToDelete); + } +} + +VkDeviceSize VmaBlockVector::CalcMaxBlockSize() const +{ + VkDeviceSize result = 0; + for(size_t i = m_Blocks.size(); i--; ) + { + result = VMA_MAX(result, m_Blocks[i]->m_pMetadata->GetSize()); + if(result >= m_PreferredBlockSize) + { + break; + } + } + return result; +} + +void VmaBlockVector::Remove(VmaDeviceMemoryBlock* pBlock) +{ + for(uint32_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) + { + if(m_Blocks[blockIndex] == pBlock) + { + VmaVectorRemove(m_Blocks, blockIndex); + return; + } + } + VMA_ASSERT(0); +} + +void VmaBlockVector::IncrementallySortBlocks() +{ + if(m_Algorithm != VMA_POOL_CREATE_LINEAR_ALGORITHM_BIT) + { + // Bubble sort only until first swap. + for(size_t i = 1; i < m_Blocks.size(); ++i) + { + if(m_Blocks[i - 1]->m_pMetadata->GetSumFreeSize() > m_Blocks[i]->m_pMetadata->GetSumFreeSize()) + { + VMA_SWAP(m_Blocks[i - 1], m_Blocks[i]); + return; + } + } + } +} + +VkResult VmaBlockVector::AllocateFromBlock( + VmaDeviceMemoryBlock* pBlock, + uint32_t currentFrameIndex, + VkDeviceSize size, + VkDeviceSize alignment, + VmaAllocationCreateFlags allocFlags, + void* pUserData, + VmaSuballocationType suballocType, + uint32_t strategy, + VmaAllocation* pAllocation) +{ + VMA_ASSERT((allocFlags & VMA_ALLOCATION_CREATE_CAN_MAKE_OTHER_LOST_BIT) == 0); + const bool isUpperAddress = (allocFlags & VMA_ALLOCATION_CREATE_UPPER_ADDRESS_BIT) != 0; + const bool mapped = (allocFlags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0; + const bool isUserDataString = (allocFlags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0; + + VmaAllocationRequest currRequest = {}; + if(pBlock->m_pMetadata->CreateAllocationRequest( + currentFrameIndex, + m_FrameInUseCount, + m_BufferImageGranularity, + size, + alignment, + isUpperAddress, + suballocType, + false, // canMakeOtherLost + strategy, + &currRequest)) + { + // Allocate from pCurrBlock. + VMA_ASSERT(currRequest.itemsToMakeLostCount == 0); + + if(mapped) + { + VkResult res = pBlock->Map(m_hAllocator, 1, VMA_NULL); + if(res != VK_SUCCESS) + { + return res; + } + } + + *pAllocation = m_hAllocator->m_AllocationObjectAllocator.Allocate(currentFrameIndex, isUserDataString); + pBlock->m_pMetadata->Alloc(currRequest, suballocType, size, *pAllocation); + UpdateHasEmptyBlock(); + (*pAllocation)->InitBlockAllocation( + pBlock, + currRequest.offset, + alignment, + size, + m_MemoryTypeIndex, + suballocType, + mapped, + (allocFlags & VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT) != 0); + VMA_HEAVY_ASSERT(pBlock->Validate()); + (*pAllocation)->SetUserData(m_hAllocator, pUserData); + m_hAllocator->m_Budget.AddAllocation(m_hAllocator->MemoryTypeIndexToHeapIndex(m_MemoryTypeIndex), size); + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + { + m_hAllocator->FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); + } + if(IsCorruptionDetectionEnabled()) + { + VkResult res = pBlock->WriteMagicValueAroundAllocation(m_hAllocator, currRequest.offset, size); + VMA_ASSERT(res == VK_SUCCESS && "Couldn't map block memory to write magic value."); + } + return VK_SUCCESS; + } + return VK_ERROR_OUT_OF_DEVICE_MEMORY; +} + +VkResult VmaBlockVector::CreateBlock(VkDeviceSize blockSize, size_t* pNewBlockIndex) +{ + VkMemoryAllocateInfo allocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO }; + allocInfo.memoryTypeIndex = m_MemoryTypeIndex; + allocInfo.allocationSize = blockSize; + +#if VMA_BUFFER_DEVICE_ADDRESS + // Every standalone block can potentially contain a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT - always enable the feature. + VkMemoryAllocateFlagsInfoKHR allocFlagsInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHR }; + if(m_hAllocator->m_UseKhrBufferDeviceAddress) + { + allocFlagsInfo.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR; + VmaPnextChainPushFront(&allocInfo, &allocFlagsInfo); + } +#endif // #if VMA_BUFFER_DEVICE_ADDRESS + + VkDeviceMemory mem = VK_NULL_HANDLE; + VkResult res = m_hAllocator->AllocateVulkanMemory(&allocInfo, &mem); + if(res < 0) + { + return res; + } + + // New VkDeviceMemory successfully created. + + // Create new Allocation for it. + VmaDeviceMemoryBlock* const pBlock = vma_new(m_hAllocator, VmaDeviceMemoryBlock)(m_hAllocator); + pBlock->Init( + m_hAllocator, + m_hParentPool, + m_MemoryTypeIndex, + mem, + allocInfo.allocationSize, + m_NextBlockId++, + m_Algorithm); + + m_Blocks.push_back(pBlock); + if(pNewBlockIndex != VMA_NULL) + { + *pNewBlockIndex = m_Blocks.size() - 1; + } + + return VK_SUCCESS; +} + +void VmaBlockVector::ApplyDefragmentationMovesCpu( + class VmaBlockVectorDefragmentationContext* pDefragCtx, + const VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves) +{ + const size_t blockCount = m_Blocks.size(); + const bool isNonCoherent = m_hAllocator->IsMemoryTypeNonCoherent(m_MemoryTypeIndex); + + enum BLOCK_FLAG + { + BLOCK_FLAG_USED = 0x00000001, + BLOCK_FLAG_MAPPED_FOR_DEFRAGMENTATION = 0x00000002, + }; + + struct BlockInfo + { + uint32_t flags; + void* pMappedData; + }; + VmaVector< BlockInfo, VmaStlAllocator > + blockInfo(blockCount, BlockInfo(), VmaStlAllocator(m_hAllocator->GetAllocationCallbacks())); + memset(blockInfo.data(), 0, blockCount * sizeof(BlockInfo)); + + // Go over all moves. Mark blocks that are used with BLOCK_FLAG_USED. + const size_t moveCount = moves.size(); + for(size_t moveIndex = 0; moveIndex < moveCount; ++moveIndex) + { + const VmaDefragmentationMove& move = moves[moveIndex]; + blockInfo[move.srcBlockIndex].flags |= BLOCK_FLAG_USED; + blockInfo[move.dstBlockIndex].flags |= BLOCK_FLAG_USED; + } + + VMA_ASSERT(pDefragCtx->res == VK_SUCCESS); + + // Go over all blocks. Get mapped pointer or map if necessary. + for(size_t blockIndex = 0; pDefragCtx->res == VK_SUCCESS && blockIndex < blockCount; ++blockIndex) + { + BlockInfo& currBlockInfo = blockInfo[blockIndex]; + VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; + if((currBlockInfo.flags & BLOCK_FLAG_USED) != 0) + { + currBlockInfo.pMappedData = pBlock->GetMappedData(); + // It is not originally mapped - map it. + if(currBlockInfo.pMappedData == VMA_NULL) + { + pDefragCtx->res = pBlock->Map(m_hAllocator, 1, &currBlockInfo.pMappedData); + if(pDefragCtx->res == VK_SUCCESS) + { + currBlockInfo.flags |= BLOCK_FLAG_MAPPED_FOR_DEFRAGMENTATION; + } + } + } + } + + // Go over all moves. Do actual data transfer. + if(pDefragCtx->res == VK_SUCCESS) + { + const VkDeviceSize nonCoherentAtomSize = m_hAllocator->m_PhysicalDeviceProperties.limits.nonCoherentAtomSize; + VkMappedMemoryRange memRange = { VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE }; + + for(size_t moveIndex = 0; moveIndex < moveCount; ++moveIndex) + { + const VmaDefragmentationMove& move = moves[moveIndex]; + + const BlockInfo& srcBlockInfo = blockInfo[move.srcBlockIndex]; + const BlockInfo& dstBlockInfo = blockInfo[move.dstBlockIndex]; + + VMA_ASSERT(srcBlockInfo.pMappedData && dstBlockInfo.pMappedData); + + // Invalidate source. + if(isNonCoherent) + { + VmaDeviceMemoryBlock* const pSrcBlock = m_Blocks[move.srcBlockIndex]; + memRange.memory = pSrcBlock->GetDeviceMemory(); + memRange.offset = VmaAlignDown(move.srcOffset, nonCoherentAtomSize); + memRange.size = VMA_MIN( + VmaAlignUp(move.size + (move.srcOffset - memRange.offset), nonCoherentAtomSize), + pSrcBlock->m_pMetadata->GetSize() - memRange.offset); + (*m_hAllocator->GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hAllocator->m_hDevice, 1, &memRange); + } + + // THE PLACE WHERE ACTUAL DATA COPY HAPPENS. + memmove( + reinterpret_cast(dstBlockInfo.pMappedData) + move.dstOffset, + reinterpret_cast(srcBlockInfo.pMappedData) + move.srcOffset, + static_cast(move.size)); + + if(IsCorruptionDetectionEnabled()) + { + VmaWriteMagicValue(dstBlockInfo.pMappedData, move.dstOffset - VMA_DEBUG_MARGIN); + VmaWriteMagicValue(dstBlockInfo.pMappedData, move.dstOffset + move.size); + } + + // Flush destination. + if(isNonCoherent) + { + VmaDeviceMemoryBlock* const pDstBlock = m_Blocks[move.dstBlockIndex]; + memRange.memory = pDstBlock->GetDeviceMemory(); + memRange.offset = VmaAlignDown(move.dstOffset, nonCoherentAtomSize); + memRange.size = VMA_MIN( + VmaAlignUp(move.size + (move.dstOffset - memRange.offset), nonCoherentAtomSize), + pDstBlock->m_pMetadata->GetSize() - memRange.offset); + (*m_hAllocator->GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hAllocator->m_hDevice, 1, &memRange); + } + } + } + + // Go over all blocks in reverse order. Unmap those that were mapped just for defragmentation. + // Regardless of pCtx->res == VK_SUCCESS. + for(size_t blockIndex = blockCount; blockIndex--; ) + { + const BlockInfo& currBlockInfo = blockInfo[blockIndex]; + if((currBlockInfo.flags & BLOCK_FLAG_MAPPED_FOR_DEFRAGMENTATION) != 0) + { + VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; + pBlock->Unmap(m_hAllocator, 1); + } + } +} + +void VmaBlockVector::ApplyDefragmentationMovesGpu( + class VmaBlockVectorDefragmentationContext* pDefragCtx, + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkCommandBuffer commandBuffer) +{ + const size_t blockCount = m_Blocks.size(); + + pDefragCtx->blockContexts.resize(blockCount); + memset(pDefragCtx->blockContexts.data(), 0, blockCount * sizeof(VmaBlockDefragmentationContext)); + + // Go over all moves. Mark blocks that are used with BLOCK_FLAG_USED. + const size_t moveCount = moves.size(); + for(size_t moveIndex = 0; moveIndex < moveCount; ++moveIndex) + { + const VmaDefragmentationMove& move = moves[moveIndex]; + + //if(move.type == VMA_ALLOCATION_TYPE_UNKNOWN) + { + // Old school move still require us to map the whole block + pDefragCtx->blockContexts[move.srcBlockIndex].flags |= VmaBlockDefragmentationContext::BLOCK_FLAG_USED; + pDefragCtx->blockContexts[move.dstBlockIndex].flags |= VmaBlockDefragmentationContext::BLOCK_FLAG_USED; + } + } + + VMA_ASSERT(pDefragCtx->res == VK_SUCCESS); + + // Go over all blocks. Create and bind buffer for whole block if necessary. + { + VkBufferCreateInfo bufCreateInfo; + VmaFillGpuDefragmentationBufferCreateInfo(bufCreateInfo); + + for(size_t blockIndex = 0; pDefragCtx->res == VK_SUCCESS && blockIndex < blockCount; ++blockIndex) + { + VmaBlockDefragmentationContext& currBlockCtx = pDefragCtx->blockContexts[blockIndex]; + VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; + if((currBlockCtx.flags & VmaBlockDefragmentationContext::BLOCK_FLAG_USED) != 0) + { + bufCreateInfo.size = pBlock->m_pMetadata->GetSize(); + pDefragCtx->res = (*m_hAllocator->GetVulkanFunctions().vkCreateBuffer)( + m_hAllocator->m_hDevice, &bufCreateInfo, m_hAllocator->GetAllocationCallbacks(), &currBlockCtx.hBuffer); + if(pDefragCtx->res == VK_SUCCESS) + { + pDefragCtx->res = (*m_hAllocator->GetVulkanFunctions().vkBindBufferMemory)( + m_hAllocator->m_hDevice, currBlockCtx.hBuffer, pBlock->GetDeviceMemory(), 0); + } + } + } + } + + // Go over all moves. Post data transfer commands to command buffer. + if(pDefragCtx->res == VK_SUCCESS) + { + for(size_t moveIndex = 0; moveIndex < moveCount; ++moveIndex) + { + const VmaDefragmentationMove& move = moves[moveIndex]; + + const VmaBlockDefragmentationContext& srcBlockCtx = pDefragCtx->blockContexts[move.srcBlockIndex]; + const VmaBlockDefragmentationContext& dstBlockCtx = pDefragCtx->blockContexts[move.dstBlockIndex]; + + VMA_ASSERT(srcBlockCtx.hBuffer && dstBlockCtx.hBuffer); + + VkBufferCopy region = { + move.srcOffset, + move.dstOffset, + move.size }; + (*m_hAllocator->GetVulkanFunctions().vkCmdCopyBuffer)( + commandBuffer, srcBlockCtx.hBuffer, dstBlockCtx.hBuffer, 1, ®ion); + } + } + + // Save buffers to defrag context for later destruction. + if(pDefragCtx->res == VK_SUCCESS && moveCount > 0) + { + pDefragCtx->res = VK_NOT_READY; + } +} + +void VmaBlockVector::FreeEmptyBlocks(VmaDefragmentationStats* pDefragmentationStats) +{ + for(size_t blockIndex = m_Blocks.size(); blockIndex--; ) + { + VmaDeviceMemoryBlock* pBlock = m_Blocks[blockIndex]; + if(pBlock->m_pMetadata->IsEmpty()) + { + if(m_Blocks.size() > m_MinBlockCount) + { + if(pDefragmentationStats != VMA_NULL) + { + ++pDefragmentationStats->deviceMemoryBlocksFreed; + pDefragmentationStats->bytesFreed += pBlock->m_pMetadata->GetSize(); + } + + VmaVectorRemove(m_Blocks, blockIndex); + pBlock->Destroy(m_hAllocator); + vma_delete(m_hAllocator, pBlock); + } + else + { + break; + } + } + } + UpdateHasEmptyBlock(); +} + +void VmaBlockVector::UpdateHasEmptyBlock() +{ + m_HasEmptyBlock = false; + for(size_t index = 0, count = m_Blocks.size(); index < count; ++index) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[index]; + if(pBlock->m_pMetadata->IsEmpty()) + { + m_HasEmptyBlock = true; + break; + } + } +} + +#if VMA_STATS_STRING_ENABLED + +void VmaBlockVector::PrintDetailedMap(class VmaJsonWriter& json) +{ + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + + json.BeginObject(); + + if(IsCustomPool()) + { + const char* poolName = m_hParentPool->GetName(); + if(poolName != VMA_NULL && poolName[0] != '\0') + { + json.WriteString("Name"); + json.WriteString(poolName); + } + + json.WriteString("MemoryTypeIndex"); + json.WriteNumber(m_MemoryTypeIndex); + + json.WriteString("BlockSize"); + json.WriteNumber(m_PreferredBlockSize); + + json.WriteString("BlockCount"); + json.BeginObject(true); + if(m_MinBlockCount > 0) + { + json.WriteString("Min"); + json.WriteNumber((uint64_t)m_MinBlockCount); + } + if(m_MaxBlockCount < SIZE_MAX) + { + json.WriteString("Max"); + json.WriteNumber((uint64_t)m_MaxBlockCount); + } + json.WriteString("Cur"); + json.WriteNumber((uint64_t)m_Blocks.size()); + json.EndObject(); + + if(m_FrameInUseCount > 0) + { + json.WriteString("FrameInUseCount"); + json.WriteNumber(m_FrameInUseCount); + } + + if(m_Algorithm != 0) + { + json.WriteString("Algorithm"); + json.WriteString(VmaAlgorithmToStr(m_Algorithm)); + } + } + else + { + json.WriteString("PreferredBlockSize"); + json.WriteNumber(m_PreferredBlockSize); + } + + json.WriteString("Blocks"); + json.BeginObject(); + for(size_t i = 0; i < m_Blocks.size(); ++i) + { + json.BeginString(); + json.ContinueString(m_Blocks[i]->GetId()); + json.EndString(); + + m_Blocks[i]->m_pMetadata->PrintDetailedMap(json); + } + json.EndObject(); + + json.EndObject(); +} + +#endif // #if VMA_STATS_STRING_ENABLED + +void VmaBlockVector::Defragment( + class VmaBlockVectorDefragmentationContext* pCtx, + VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags, + VkDeviceSize& maxCpuBytesToMove, uint32_t& maxCpuAllocationsToMove, + VkDeviceSize& maxGpuBytesToMove, uint32_t& maxGpuAllocationsToMove, + VkCommandBuffer commandBuffer) +{ + pCtx->res = VK_SUCCESS; + + const VkMemoryPropertyFlags memPropFlags = + m_hAllocator->m_MemProps.memoryTypes[m_MemoryTypeIndex].propertyFlags; + const bool isHostVisible = (memPropFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0; + + const bool canDefragmentOnCpu = maxCpuBytesToMove > 0 && maxCpuAllocationsToMove > 0 && + isHostVisible; + const bool canDefragmentOnGpu = maxGpuBytesToMove > 0 && maxGpuAllocationsToMove > 0 && + !IsCorruptionDetectionEnabled() && + ((1u << m_MemoryTypeIndex) & m_hAllocator->GetGpuDefragmentationMemoryTypeBits()) != 0; + + // There are options to defragment this memory type. + if(canDefragmentOnCpu || canDefragmentOnGpu) + { + bool defragmentOnGpu; + // There is only one option to defragment this memory type. + if(canDefragmentOnGpu != canDefragmentOnCpu) + { + defragmentOnGpu = canDefragmentOnGpu; + } + // Both options are available: Heuristics to choose the best one. + else + { + defragmentOnGpu = (memPropFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) != 0 || + m_hAllocator->IsIntegratedGpu(); + } + + bool overlappingMoveSupported = !defragmentOnGpu; + + if(m_hAllocator->m_UseMutex) + { + if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL) + { + if(!m_Mutex.TryLockWrite()) + { + pCtx->res = VK_ERROR_INITIALIZATION_FAILED; + return; + } + } + else + { + m_Mutex.LockWrite(); + pCtx->mutexLocked = true; + } + } + + pCtx->Begin(overlappingMoveSupported, flags); + + // Defragment. + + const VkDeviceSize maxBytesToMove = defragmentOnGpu ? maxGpuBytesToMove : maxCpuBytesToMove; + const uint32_t maxAllocationsToMove = defragmentOnGpu ? maxGpuAllocationsToMove : maxCpuAllocationsToMove; + pCtx->res = pCtx->GetAlgorithm()->Defragment(pCtx->defragmentationMoves, maxBytesToMove, maxAllocationsToMove, flags); + + // Accumulate statistics. + if(pStats != VMA_NULL) + { + const VkDeviceSize bytesMoved = pCtx->GetAlgorithm()->GetBytesMoved(); + const uint32_t allocationsMoved = pCtx->GetAlgorithm()->GetAllocationsMoved(); + pStats->bytesMoved += bytesMoved; + pStats->allocationsMoved += allocationsMoved; + VMA_ASSERT(bytesMoved <= maxBytesToMove); + VMA_ASSERT(allocationsMoved <= maxAllocationsToMove); + if(defragmentOnGpu) + { + maxGpuBytesToMove -= bytesMoved; + maxGpuAllocationsToMove -= allocationsMoved; + } + else + { + maxCpuBytesToMove -= bytesMoved; + maxCpuAllocationsToMove -= allocationsMoved; + } + } + + if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL) + { + if(m_hAllocator->m_UseMutex) + m_Mutex.UnlockWrite(); + + if(pCtx->res >= VK_SUCCESS && !pCtx->defragmentationMoves.empty()) + pCtx->res = VK_NOT_READY; + + return; + } + + if(pCtx->res >= VK_SUCCESS) + { + if(defragmentOnGpu) + { + ApplyDefragmentationMovesGpu(pCtx, pCtx->defragmentationMoves, commandBuffer); + } + else + { + ApplyDefragmentationMovesCpu(pCtx, pCtx->defragmentationMoves); + } + } + } +} + +void VmaBlockVector::DefragmentationEnd( + class VmaBlockVectorDefragmentationContext* pCtx, + uint32_t flags, + VmaDefragmentationStats* pStats) +{ + if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL && m_hAllocator->m_UseMutex) + { + VMA_ASSERT(pCtx->mutexLocked == false); + + // Incremental defragmentation doesn't hold the lock, so when we enter here we don't actually have any + // lock protecting us. Since we mutate state here, we have to take the lock out now + m_Mutex.LockWrite(); + pCtx->mutexLocked = true; + } + + // If the mutex isn't locked we didn't do any work and there is nothing to delete. + if(pCtx->mutexLocked || !m_hAllocator->m_UseMutex) + { + // Destroy buffers. + for(size_t blockIndex = pCtx->blockContexts.size(); blockIndex--;) + { + VmaBlockDefragmentationContext &blockCtx = pCtx->blockContexts[blockIndex]; + if(blockCtx.hBuffer) + { + (*m_hAllocator->GetVulkanFunctions().vkDestroyBuffer)(m_hAllocator->m_hDevice, blockCtx.hBuffer, m_hAllocator->GetAllocationCallbacks()); + } + } + + if(pCtx->res >= VK_SUCCESS) + { + FreeEmptyBlocks(pStats); + } + } + + if(pCtx->mutexLocked) + { + VMA_ASSERT(m_hAllocator->m_UseMutex); + m_Mutex.UnlockWrite(); + } +} + +uint32_t VmaBlockVector::ProcessDefragmentations( + class VmaBlockVectorDefragmentationContext *pCtx, + VmaDefragmentationPassMoveInfo* pMove, uint32_t maxMoves) +{ + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + + const uint32_t moveCount = std::min(uint32_t(pCtx->defragmentationMoves.size()) - pCtx->defragmentationMovesProcessed, maxMoves); + + for(uint32_t i = 0; i < moveCount; ++ i) + { + VmaDefragmentationMove& move = pCtx->defragmentationMoves[pCtx->defragmentationMovesProcessed + i]; + + pMove->allocation = move.hAllocation; + pMove->memory = move.pDstBlock->GetDeviceMemory(); + pMove->offset = move.dstOffset; + + ++ pMove; + } + + pCtx->defragmentationMovesProcessed += moveCount; + + return moveCount; +} + +void VmaBlockVector::CommitDefragmentations( + class VmaBlockVectorDefragmentationContext *pCtx, + VmaDefragmentationStats* pStats) +{ + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + + for(uint32_t i = pCtx->defragmentationMovesCommitted; i < pCtx->defragmentationMovesProcessed; ++ i) + { + const VmaDefragmentationMove &move = pCtx->defragmentationMoves[i]; + + move.pSrcBlock->m_pMetadata->FreeAtOffset(move.srcOffset); + move.hAllocation->ChangeBlockAllocation(m_hAllocator, move.pDstBlock, move.dstOffset); + } + + pCtx->defragmentationMovesCommitted = pCtx->defragmentationMovesProcessed; + FreeEmptyBlocks(pStats); +} + +size_t VmaBlockVector::CalcAllocationCount() const +{ + size_t result = 0; + for(size_t i = 0; i < m_Blocks.size(); ++i) + { + result += m_Blocks[i]->m_pMetadata->GetAllocationCount(); + } + return result; +} + +bool VmaBlockVector::IsBufferImageGranularityConflictPossible() const +{ + if(m_BufferImageGranularity == 1) + { + return false; + } + VmaSuballocationType lastSuballocType = VMA_SUBALLOCATION_TYPE_FREE; + for(size_t i = 0, count = m_Blocks.size(); i < count; ++i) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[i]; + VMA_ASSERT(m_Algorithm == 0); + VmaBlockMetadata_Generic* const pMetadata = (VmaBlockMetadata_Generic*)pBlock->m_pMetadata; + if(pMetadata->IsBufferImageGranularityConflictPossible(m_BufferImageGranularity, lastSuballocType)) + { + return true; + } + } + return false; +} + +void VmaBlockVector::MakePoolAllocationsLost( + uint32_t currentFrameIndex, + size_t* pLostAllocationCount) +{ + VmaMutexLockWrite lock(m_Mutex, m_hAllocator->m_UseMutex); + size_t lostAllocationCount = 0; + for(uint32_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + lostAllocationCount += pBlock->m_pMetadata->MakeAllocationsLost(currentFrameIndex, m_FrameInUseCount); + } + if(pLostAllocationCount != VMA_NULL) + { + *pLostAllocationCount = lostAllocationCount; + } +} + +VkResult VmaBlockVector::CheckCorruption() +{ + if(!IsCorruptionDetectionEnabled()) + { + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + for(uint32_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) + { + VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + VkResult res = pBlock->CheckCorruption(m_hAllocator); + if(res != VK_SUCCESS) + { + return res; + } + } + return VK_SUCCESS; +} + +void VmaBlockVector::AddStats(VmaStats* pStats) +{ + const uint32_t memTypeIndex = m_MemoryTypeIndex; + const uint32_t memHeapIndex = m_hAllocator->MemoryTypeIndexToHeapIndex(memTypeIndex); + + VmaMutexLockRead lock(m_Mutex, m_hAllocator->m_UseMutex); + + for(uint32_t blockIndex = 0; blockIndex < m_Blocks.size(); ++blockIndex) + { + const VmaDeviceMemoryBlock* const pBlock = m_Blocks[blockIndex]; + VMA_ASSERT(pBlock); + VMA_HEAVY_ASSERT(pBlock->Validate()); + VmaStatInfo allocationStatInfo; + pBlock->m_pMetadata->CalcAllocationStatInfo(allocationStatInfo); + VmaAddStatInfo(pStats->total, allocationStatInfo); + VmaAddStatInfo(pStats->memoryType[memTypeIndex], allocationStatInfo); + VmaAddStatInfo(pStats->memoryHeap[memHeapIndex], allocationStatInfo); + } +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaDefragmentationAlgorithm_Generic members definition + +VmaDefragmentationAlgorithm_Generic::VmaDefragmentationAlgorithm_Generic( + VmaAllocator hAllocator, + VmaBlockVector* pBlockVector, + uint32_t currentFrameIndex, + bool overlappingMoveSupported) : + VmaDefragmentationAlgorithm(hAllocator, pBlockVector, currentFrameIndex), + m_AllocationCount(0), + m_AllAllocations(false), + m_BytesMoved(0), + m_AllocationsMoved(0), + m_Blocks(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) +{ + // Create block info for each block. + const size_t blockCount = m_pBlockVector->m_Blocks.size(); + for(size_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) + { + BlockInfo* pBlockInfo = vma_new(m_hAllocator, BlockInfo)(m_hAllocator->GetAllocationCallbacks()); + pBlockInfo->m_OriginalBlockIndex = blockIndex; + pBlockInfo->m_pBlock = m_pBlockVector->m_Blocks[blockIndex]; + m_Blocks.push_back(pBlockInfo); + } + + // Sort them by m_pBlock pointer value. + VMA_SORT(m_Blocks.begin(), m_Blocks.end(), BlockPointerLess()); +} + +VmaDefragmentationAlgorithm_Generic::~VmaDefragmentationAlgorithm_Generic() +{ + for(size_t i = m_Blocks.size(); i--; ) + { + vma_delete(m_hAllocator, m_Blocks[i]); + } +} + +void VmaDefragmentationAlgorithm_Generic::AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) +{ + // Now as we are inside VmaBlockVector::m_Mutex, we can make final check if this allocation was not lost. + if(hAlloc->GetLastUseFrameIndex() != VMA_FRAME_INDEX_LOST) + { + VmaDeviceMemoryBlock* pBlock = hAlloc->GetBlock(); + BlockInfoVector::iterator it = VmaBinaryFindFirstNotLess(m_Blocks.begin(), m_Blocks.end(), pBlock, BlockPointerLess()); + if(it != m_Blocks.end() && (*it)->m_pBlock == pBlock) + { + AllocationInfo allocInfo = AllocationInfo(hAlloc, pChanged); + (*it)->m_Allocations.push_back(allocInfo); + } + else + { + VMA_ASSERT(0); + } + + ++m_AllocationCount; + } +} + +VkResult VmaDefragmentationAlgorithm_Generic::DefragmentRound( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + bool freeOldAllocations) +{ + if(m_Blocks.empty()) + { + return VK_SUCCESS; + } + + // This is a choice based on research. + // Option 1: + uint32_t strategy = VMA_ALLOCATION_CREATE_STRATEGY_MIN_TIME_BIT; + // Option 2: + //uint32_t strategy = VMA_ALLOCATION_CREATE_STRATEGY_MIN_MEMORY_BIT; + // Option 3: + //uint32_t strategy = VMA_ALLOCATION_CREATE_STRATEGY_MIN_FRAGMENTATION_BIT; + + size_t srcBlockMinIndex = 0; + // When FAST_ALGORITHM, move allocations from only last out of blocks that contain non-movable allocations. + /* + if(m_AlgorithmFlags & VMA_DEFRAGMENTATION_FAST_ALGORITHM_BIT) + { + const size_t blocksWithNonMovableCount = CalcBlocksWithNonMovableCount(); + if(blocksWithNonMovableCount > 0) + { + srcBlockMinIndex = blocksWithNonMovableCount - 1; + } + } + */ + + size_t srcBlockIndex = m_Blocks.size() - 1; + size_t srcAllocIndex = SIZE_MAX; + for(;;) + { + // 1. Find next allocation to move. + // 1.1. Start from last to first m_Blocks - they are sorted from most "destination" to most "source". + // 1.2. Then start from last to first m_Allocations. + while(srcAllocIndex >= m_Blocks[srcBlockIndex]->m_Allocations.size()) + { + if(m_Blocks[srcBlockIndex]->m_Allocations.empty()) + { + // Finished: no more allocations to process. + if(srcBlockIndex == srcBlockMinIndex) + { + return VK_SUCCESS; + } + else + { + --srcBlockIndex; + srcAllocIndex = SIZE_MAX; + } + } + else + { + srcAllocIndex = m_Blocks[srcBlockIndex]->m_Allocations.size() - 1; + } + } + + BlockInfo* pSrcBlockInfo = m_Blocks[srcBlockIndex]; + AllocationInfo& allocInfo = pSrcBlockInfo->m_Allocations[srcAllocIndex]; + + const VkDeviceSize size = allocInfo.m_hAllocation->GetSize(); + const VkDeviceSize srcOffset = allocInfo.m_hAllocation->GetOffset(); + const VkDeviceSize alignment = allocInfo.m_hAllocation->GetAlignment(); + const VmaSuballocationType suballocType = allocInfo.m_hAllocation->GetSuballocationType(); + + // 2. Try to find new place for this allocation in preceding or current block. + for(size_t dstBlockIndex = 0; dstBlockIndex <= srcBlockIndex; ++dstBlockIndex) + { + BlockInfo* pDstBlockInfo = m_Blocks[dstBlockIndex]; + VmaAllocationRequest dstAllocRequest; + if(pDstBlockInfo->m_pBlock->m_pMetadata->CreateAllocationRequest( + m_CurrentFrameIndex, + m_pBlockVector->GetFrameInUseCount(), + m_pBlockVector->GetBufferImageGranularity(), + size, + alignment, + false, // upperAddress + suballocType, + false, // canMakeOtherLost + strategy, + &dstAllocRequest) && + MoveMakesSense( + dstBlockIndex, dstAllocRequest.offset, srcBlockIndex, srcOffset)) + { + VMA_ASSERT(dstAllocRequest.itemsToMakeLostCount == 0); + + // Reached limit on number of allocations or bytes to move. + if((m_AllocationsMoved + 1 > maxAllocationsToMove) || + (m_BytesMoved + size > maxBytesToMove)) + { + return VK_SUCCESS; + } + + VmaDefragmentationMove move = {}; + move.srcBlockIndex = pSrcBlockInfo->m_OriginalBlockIndex; + move.dstBlockIndex = pDstBlockInfo->m_OriginalBlockIndex; + move.srcOffset = srcOffset; + move.dstOffset = dstAllocRequest.offset; + move.size = size; + move.hAllocation = allocInfo.m_hAllocation; + move.pSrcBlock = pSrcBlockInfo->m_pBlock; + move.pDstBlock = pDstBlockInfo->m_pBlock; + + moves.push_back(move); + + pDstBlockInfo->m_pBlock->m_pMetadata->Alloc( + dstAllocRequest, + suballocType, + size, + allocInfo.m_hAllocation); + + if(freeOldAllocations) + { + pSrcBlockInfo->m_pBlock->m_pMetadata->FreeAtOffset(srcOffset); + allocInfo.m_hAllocation->ChangeBlockAllocation(m_hAllocator, pDstBlockInfo->m_pBlock, dstAllocRequest.offset); + } + + if(allocInfo.m_pChanged != VMA_NULL) + { + *allocInfo.m_pChanged = VK_TRUE; + } + + ++m_AllocationsMoved; + m_BytesMoved += size; + + VmaVectorRemove(pSrcBlockInfo->m_Allocations, srcAllocIndex); + + break; + } + } + + // If not processed, this allocInfo remains in pBlockInfo->m_Allocations for next round. + + if(srcAllocIndex > 0) + { + --srcAllocIndex; + } + else + { + if(srcBlockIndex > 0) + { + --srcBlockIndex; + srcAllocIndex = SIZE_MAX; + } + else + { + return VK_SUCCESS; + } + } + } +} + +size_t VmaDefragmentationAlgorithm_Generic::CalcBlocksWithNonMovableCount() const +{ + size_t result = 0; + for(size_t i = 0; i < m_Blocks.size(); ++i) + { + if(m_Blocks[i]->m_HasNonMovableAllocations) + { + ++result; + } + } + return result; +} + +VkResult VmaDefragmentationAlgorithm_Generic::Defragment( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + VmaDefragmentationFlags flags) +{ + if(!m_AllAllocations && m_AllocationCount == 0) + { + return VK_SUCCESS; + } + + const size_t blockCount = m_Blocks.size(); + for(size_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) + { + BlockInfo* pBlockInfo = m_Blocks[blockIndex]; + + if(m_AllAllocations) + { + VmaBlockMetadata_Generic* pMetadata = (VmaBlockMetadata_Generic*)pBlockInfo->m_pBlock->m_pMetadata; + for(VmaSuballocationList::const_iterator it = pMetadata->m_Suballocations.begin(); + it != pMetadata->m_Suballocations.end(); + ++it) + { + if(it->type != VMA_SUBALLOCATION_TYPE_FREE) + { + AllocationInfo allocInfo = AllocationInfo(it->hAllocation, VMA_NULL); + pBlockInfo->m_Allocations.push_back(allocInfo); + } + } + } + + pBlockInfo->CalcHasNonMovableAllocations(); + + // This is a choice based on research. + // Option 1: + pBlockInfo->SortAllocationsByOffsetDescending(); + // Option 2: + //pBlockInfo->SortAllocationsBySizeDescending(); + } + + // Sort m_Blocks this time by the main criterium, from most "destination" to most "source" blocks. + VMA_SORT(m_Blocks.begin(), m_Blocks.end(), BlockInfoCompareMoveDestination()); + + // This is a choice based on research. + const uint32_t roundCount = 2; + + // Execute defragmentation rounds (the main part). + VkResult result = VK_SUCCESS; + for(uint32_t round = 0; (round < roundCount) && (result == VK_SUCCESS); ++round) + { + result = DefragmentRound(moves, maxBytesToMove, maxAllocationsToMove, !(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL)); + } + + return result; +} + +bool VmaDefragmentationAlgorithm_Generic::MoveMakesSense( + size_t dstBlockIndex, VkDeviceSize dstOffset, + size_t srcBlockIndex, VkDeviceSize srcOffset) +{ + if(dstBlockIndex < srcBlockIndex) + { + return true; + } + if(dstBlockIndex > srcBlockIndex) + { + return false; + } + if(dstOffset < srcOffset) + { + return true; + } + return false; +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaDefragmentationAlgorithm_Fast + +VmaDefragmentationAlgorithm_Fast::VmaDefragmentationAlgorithm_Fast( + VmaAllocator hAllocator, + VmaBlockVector* pBlockVector, + uint32_t currentFrameIndex, + bool overlappingMoveSupported) : + VmaDefragmentationAlgorithm(hAllocator, pBlockVector, currentFrameIndex), + m_OverlappingMoveSupported(overlappingMoveSupported), + m_AllocationCount(0), + m_AllAllocations(false), + m_BytesMoved(0), + m_AllocationsMoved(0), + m_BlockInfos(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) +{ + VMA_ASSERT(VMA_DEBUG_MARGIN == 0); + +} + +VmaDefragmentationAlgorithm_Fast::~VmaDefragmentationAlgorithm_Fast() +{ +} + +VkResult VmaDefragmentationAlgorithm_Fast::Defragment( + VmaVector< VmaDefragmentationMove, VmaStlAllocator >& moves, + VkDeviceSize maxBytesToMove, + uint32_t maxAllocationsToMove, + VmaDefragmentationFlags flags) +{ + VMA_ASSERT(m_AllAllocations || m_pBlockVector->CalcAllocationCount() == m_AllocationCount); + + const size_t blockCount = m_pBlockVector->GetBlockCount(); + if(blockCount == 0 || maxBytesToMove == 0 || maxAllocationsToMove == 0) + { + return VK_SUCCESS; + } + + PreprocessMetadata(); + + // Sort blocks in order from most destination. + + m_BlockInfos.resize(blockCount); + for(size_t i = 0; i < blockCount; ++i) + { + m_BlockInfos[i].origBlockIndex = i; + } + + VMA_SORT(m_BlockInfos.begin(), m_BlockInfos.end(), [this](const BlockInfo& lhs, const BlockInfo& rhs) -> bool { + return m_pBlockVector->GetBlock(lhs.origBlockIndex)->m_pMetadata->GetSumFreeSize() < + m_pBlockVector->GetBlock(rhs.origBlockIndex)->m_pMetadata->GetSumFreeSize(); + }); + + // THE MAIN ALGORITHM + + FreeSpaceDatabase freeSpaceDb; + + size_t dstBlockInfoIndex = 0; + size_t dstOrigBlockIndex = m_BlockInfos[dstBlockInfoIndex].origBlockIndex; + VmaDeviceMemoryBlock* pDstBlock = m_pBlockVector->GetBlock(dstOrigBlockIndex); + VmaBlockMetadata_Generic* pDstMetadata = (VmaBlockMetadata_Generic*)pDstBlock->m_pMetadata; + VkDeviceSize dstBlockSize = pDstMetadata->GetSize(); + VkDeviceSize dstOffset = 0; + + bool end = false; + for(size_t srcBlockInfoIndex = 0; !end && srcBlockInfoIndex < blockCount; ++srcBlockInfoIndex) + { + const size_t srcOrigBlockIndex = m_BlockInfos[srcBlockInfoIndex].origBlockIndex; + VmaDeviceMemoryBlock* const pSrcBlock = m_pBlockVector->GetBlock(srcOrigBlockIndex); + VmaBlockMetadata_Generic* const pSrcMetadata = (VmaBlockMetadata_Generic*)pSrcBlock->m_pMetadata; + for(VmaSuballocationList::iterator srcSuballocIt = pSrcMetadata->m_Suballocations.begin(); + !end && srcSuballocIt != pSrcMetadata->m_Suballocations.end(); ) + { + VmaAllocation_T* const pAlloc = srcSuballocIt->hAllocation; + const VkDeviceSize srcAllocAlignment = pAlloc->GetAlignment(); + const VkDeviceSize srcAllocSize = srcSuballocIt->size; + if(m_AllocationsMoved == maxAllocationsToMove || + m_BytesMoved + srcAllocSize > maxBytesToMove) + { + end = true; + break; + } + const VkDeviceSize srcAllocOffset = srcSuballocIt->offset; + + VmaDefragmentationMove move = {}; + // Try to place it in one of free spaces from the database. + size_t freeSpaceInfoIndex; + VkDeviceSize dstAllocOffset; + if(freeSpaceDb.Fetch(srcAllocAlignment, srcAllocSize, + freeSpaceInfoIndex, dstAllocOffset)) + { + size_t freeSpaceOrigBlockIndex = m_BlockInfos[freeSpaceInfoIndex].origBlockIndex; + VmaDeviceMemoryBlock* pFreeSpaceBlock = m_pBlockVector->GetBlock(freeSpaceOrigBlockIndex); + VmaBlockMetadata_Generic* pFreeSpaceMetadata = (VmaBlockMetadata_Generic*)pFreeSpaceBlock->m_pMetadata; + + // Same block + if(freeSpaceInfoIndex == srcBlockInfoIndex) + { + VMA_ASSERT(dstAllocOffset <= srcAllocOffset); + + // MOVE OPTION 1: Move the allocation inside the same block by decreasing offset. + + VmaSuballocation suballoc = *srcSuballocIt; + suballoc.offset = dstAllocOffset; + suballoc.hAllocation->ChangeOffset(dstAllocOffset); + m_BytesMoved += srcAllocSize; + ++m_AllocationsMoved; + + VmaSuballocationList::iterator nextSuballocIt = srcSuballocIt; + ++nextSuballocIt; + pSrcMetadata->m_Suballocations.erase(srcSuballocIt); + srcSuballocIt = nextSuballocIt; + + InsertSuballoc(pFreeSpaceMetadata, suballoc); + + move.srcBlockIndex = srcOrigBlockIndex; + move.dstBlockIndex = freeSpaceOrigBlockIndex; + move.srcOffset = srcAllocOffset; + move.dstOffset = dstAllocOffset; + move.size = srcAllocSize; + + moves.push_back(move); + } + // Different block + else + { + // MOVE OPTION 2: Move the allocation to a different block. + + VMA_ASSERT(freeSpaceInfoIndex < srcBlockInfoIndex); + + VmaSuballocation suballoc = *srcSuballocIt; + suballoc.offset = dstAllocOffset; + suballoc.hAllocation->ChangeBlockAllocation(m_hAllocator, pFreeSpaceBlock, dstAllocOffset); + m_BytesMoved += srcAllocSize; + ++m_AllocationsMoved; + + VmaSuballocationList::iterator nextSuballocIt = srcSuballocIt; + ++nextSuballocIt; + pSrcMetadata->m_Suballocations.erase(srcSuballocIt); + srcSuballocIt = nextSuballocIt; + + InsertSuballoc(pFreeSpaceMetadata, suballoc); + + move.srcBlockIndex = srcOrigBlockIndex; + move.dstBlockIndex = freeSpaceOrigBlockIndex; + move.srcOffset = srcAllocOffset; + move.dstOffset = dstAllocOffset; + move.size = srcAllocSize; + + moves.push_back(move); + } + } + else + { + dstAllocOffset = VmaAlignUp(dstOffset, srcAllocAlignment); + + // If the allocation doesn't fit before the end of dstBlock, forward to next block. + while(dstBlockInfoIndex < srcBlockInfoIndex && + dstAllocOffset + srcAllocSize > dstBlockSize) + { + // But before that, register remaining free space at the end of dst block. + freeSpaceDb.Register(dstBlockInfoIndex, dstOffset, dstBlockSize - dstOffset); + + ++dstBlockInfoIndex; + dstOrigBlockIndex = m_BlockInfos[dstBlockInfoIndex].origBlockIndex; + pDstBlock = m_pBlockVector->GetBlock(dstOrigBlockIndex); + pDstMetadata = (VmaBlockMetadata_Generic*)pDstBlock->m_pMetadata; + dstBlockSize = pDstMetadata->GetSize(); + dstOffset = 0; + dstAllocOffset = 0; + } + + // Same block + if(dstBlockInfoIndex == srcBlockInfoIndex) + { + VMA_ASSERT(dstAllocOffset <= srcAllocOffset); + + const bool overlap = dstAllocOffset + srcAllocSize > srcAllocOffset; + + bool skipOver = overlap; + if(overlap && m_OverlappingMoveSupported && dstAllocOffset < srcAllocOffset) + { + // If destination and source place overlap, skip if it would move it + // by only < 1/64 of its size. + skipOver = (srcAllocOffset - dstAllocOffset) * 64 < srcAllocSize; + } + + if(skipOver) + { + freeSpaceDb.Register(dstBlockInfoIndex, dstOffset, srcAllocOffset - dstOffset); + + dstOffset = srcAllocOffset + srcAllocSize; + ++srcSuballocIt; + } + // MOVE OPTION 1: Move the allocation inside the same block by decreasing offset. + else + { + srcSuballocIt->offset = dstAllocOffset; + srcSuballocIt->hAllocation->ChangeOffset(dstAllocOffset); + dstOffset = dstAllocOffset + srcAllocSize; + m_BytesMoved += srcAllocSize; + ++m_AllocationsMoved; + ++srcSuballocIt; + + move.srcBlockIndex = srcOrigBlockIndex; + move.dstBlockIndex = dstOrigBlockIndex; + move.srcOffset = srcAllocOffset; + move.dstOffset = dstAllocOffset; + move.size = srcAllocSize; + + moves.push_back(move); + } + } + // Different block + else + { + // MOVE OPTION 2: Move the allocation to a different block. + + VMA_ASSERT(dstBlockInfoIndex < srcBlockInfoIndex); + VMA_ASSERT(dstAllocOffset + srcAllocSize <= dstBlockSize); + + VmaSuballocation suballoc = *srcSuballocIt; + suballoc.offset = dstAllocOffset; + suballoc.hAllocation->ChangeBlockAllocation(m_hAllocator, pDstBlock, dstAllocOffset); + dstOffset = dstAllocOffset + srcAllocSize; + m_BytesMoved += srcAllocSize; + ++m_AllocationsMoved; + + VmaSuballocationList::iterator nextSuballocIt = srcSuballocIt; + ++nextSuballocIt; + pSrcMetadata->m_Suballocations.erase(srcSuballocIt); + srcSuballocIt = nextSuballocIt; + + pDstMetadata->m_Suballocations.push_back(suballoc); + + move.srcBlockIndex = srcOrigBlockIndex; + move.dstBlockIndex = dstOrigBlockIndex; + move.srcOffset = srcAllocOffset; + move.dstOffset = dstAllocOffset; + move.size = srcAllocSize; + + moves.push_back(move); + } + } + } + } + + m_BlockInfos.clear(); + + PostprocessMetadata(); + + return VK_SUCCESS; +} + +void VmaDefragmentationAlgorithm_Fast::PreprocessMetadata() +{ + const size_t blockCount = m_pBlockVector->GetBlockCount(); + for(size_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) + { + VmaBlockMetadata_Generic* const pMetadata = + (VmaBlockMetadata_Generic*)m_pBlockVector->GetBlock(blockIndex)->m_pMetadata; + pMetadata->m_FreeCount = 0; + pMetadata->m_SumFreeSize = pMetadata->GetSize(); + pMetadata->m_FreeSuballocationsBySize.clear(); + for(VmaSuballocationList::iterator it = pMetadata->m_Suballocations.begin(); + it != pMetadata->m_Suballocations.end(); ) + { + if(it->type == VMA_SUBALLOCATION_TYPE_FREE) + { + VmaSuballocationList::iterator nextIt = it; + ++nextIt; + pMetadata->m_Suballocations.erase(it); + it = nextIt; + } + else + { + ++it; + } + } + } +} + +void VmaDefragmentationAlgorithm_Fast::PostprocessMetadata() +{ + const size_t blockCount = m_pBlockVector->GetBlockCount(); + for(size_t blockIndex = 0; blockIndex < blockCount; ++blockIndex) + { + VmaBlockMetadata_Generic* const pMetadata = + (VmaBlockMetadata_Generic*)m_pBlockVector->GetBlock(blockIndex)->m_pMetadata; + const VkDeviceSize blockSize = pMetadata->GetSize(); + + // No allocations in this block - entire area is free. + if(pMetadata->m_Suballocations.empty()) + { + pMetadata->m_FreeCount = 1; + //pMetadata->m_SumFreeSize is already set to blockSize. + VmaSuballocation suballoc = { + 0, // offset + blockSize, // size + VMA_NULL, // hAllocation + VMA_SUBALLOCATION_TYPE_FREE }; + pMetadata->m_Suballocations.push_back(suballoc); + pMetadata->RegisterFreeSuballocation(pMetadata->m_Suballocations.begin()); + } + // There are some allocations in this block. + else + { + VkDeviceSize offset = 0; + VmaSuballocationList::iterator it; + for(it = pMetadata->m_Suballocations.begin(); + it != pMetadata->m_Suballocations.end(); + ++it) + { + VMA_ASSERT(it->type != VMA_SUBALLOCATION_TYPE_FREE); + VMA_ASSERT(it->offset >= offset); + + // Need to insert preceding free space. + if(it->offset > offset) + { + ++pMetadata->m_FreeCount; + const VkDeviceSize freeSize = it->offset - offset; + VmaSuballocation suballoc = { + offset, // offset + freeSize, // size + VMA_NULL, // hAllocation + VMA_SUBALLOCATION_TYPE_FREE }; + VmaSuballocationList::iterator precedingFreeIt = pMetadata->m_Suballocations.insert(it, suballoc); + if(freeSize >= VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + pMetadata->m_FreeSuballocationsBySize.push_back(precedingFreeIt); + } + } + + pMetadata->m_SumFreeSize -= it->size; + offset = it->offset + it->size; + } + + // Need to insert trailing free space. + if(offset < blockSize) + { + ++pMetadata->m_FreeCount; + const VkDeviceSize freeSize = blockSize - offset; + VmaSuballocation suballoc = { + offset, // offset + freeSize, // size + VMA_NULL, // hAllocation + VMA_SUBALLOCATION_TYPE_FREE }; + VMA_ASSERT(it == pMetadata->m_Suballocations.end()); + VmaSuballocationList::iterator trailingFreeIt = pMetadata->m_Suballocations.insert(it, suballoc); + if(freeSize > VMA_MIN_FREE_SUBALLOCATION_SIZE_TO_REGISTER) + { + pMetadata->m_FreeSuballocationsBySize.push_back(trailingFreeIt); + } + } + + VMA_SORT( + pMetadata->m_FreeSuballocationsBySize.begin(), + pMetadata->m_FreeSuballocationsBySize.end(), + VmaSuballocationItemSizeLess()); + } + + VMA_HEAVY_ASSERT(pMetadata->Validate()); + } +} + +void VmaDefragmentationAlgorithm_Fast::InsertSuballoc(VmaBlockMetadata_Generic* pMetadata, const VmaSuballocation& suballoc) +{ + // TODO: Optimize somehow. Remember iterator instead of searching for it linearly. + VmaSuballocationList::iterator it = pMetadata->m_Suballocations.begin(); + while(it != pMetadata->m_Suballocations.end()) + { + if(it->offset < suballoc.offset) + { + ++it; + } + } + pMetadata->m_Suballocations.insert(it, suballoc); +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaBlockVectorDefragmentationContext + +VmaBlockVectorDefragmentationContext::VmaBlockVectorDefragmentationContext( + VmaAllocator hAllocator, + VmaPool hCustomPool, + VmaBlockVector* pBlockVector, + uint32_t currFrameIndex) : + res(VK_SUCCESS), + mutexLocked(false), + blockContexts(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + defragmentationMoves(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + defragmentationMovesProcessed(0), + defragmentationMovesCommitted(0), + hasDefragmentationPlan(0), + m_hAllocator(hAllocator), + m_hCustomPool(hCustomPool), + m_pBlockVector(pBlockVector), + m_CurrFrameIndex(currFrameIndex), + m_pAlgorithm(VMA_NULL), + m_Allocations(VmaStlAllocator(hAllocator->GetAllocationCallbacks())), + m_AllAllocations(false) +{ +} + +VmaBlockVectorDefragmentationContext::~VmaBlockVectorDefragmentationContext() +{ + vma_delete(m_hAllocator, m_pAlgorithm); +} + +void VmaBlockVectorDefragmentationContext::AddAllocation(VmaAllocation hAlloc, VkBool32* pChanged) +{ + AllocInfo info = { hAlloc, pChanged }; + m_Allocations.push_back(info); +} + +void VmaBlockVectorDefragmentationContext::Begin(bool overlappingMoveSupported, VmaDefragmentationFlags flags) +{ + const bool allAllocations = m_AllAllocations || + m_Allocations.size() == m_pBlockVector->CalcAllocationCount(); + + /******************************** + HERE IS THE CHOICE OF DEFRAGMENTATION ALGORITHM. + ********************************/ + + /* + Fast algorithm is supported only when certain criteria are met: + - VMA_DEBUG_MARGIN is 0. + - All allocations in this block vector are moveable. + - There is no possibility of image/buffer granularity conflict. + - The defragmentation is not incremental + */ + if(VMA_DEBUG_MARGIN == 0 && + allAllocations && + !m_pBlockVector->IsBufferImageGranularityConflictPossible() && + !(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL)) + { + m_pAlgorithm = vma_new(m_hAllocator, VmaDefragmentationAlgorithm_Fast)( + m_hAllocator, m_pBlockVector, m_CurrFrameIndex, overlappingMoveSupported); + } + else + { + m_pAlgorithm = vma_new(m_hAllocator, VmaDefragmentationAlgorithm_Generic)( + m_hAllocator, m_pBlockVector, m_CurrFrameIndex, overlappingMoveSupported); + } + + if(allAllocations) + { + m_pAlgorithm->AddAll(); + } + else + { + for(size_t i = 0, count = m_Allocations.size(); i < count; ++i) + { + m_pAlgorithm->AddAllocation(m_Allocations[i].hAlloc, m_Allocations[i].pChanged); + } + } +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaDefragmentationContext + +VmaDefragmentationContext_T::VmaDefragmentationContext_T( + VmaAllocator hAllocator, + uint32_t currFrameIndex, + uint32_t flags, + VmaDefragmentationStats* pStats) : + m_hAllocator(hAllocator), + m_CurrFrameIndex(currFrameIndex), + m_Flags(flags), + m_pStats(pStats), + m_CustomPoolContexts(VmaStlAllocator(hAllocator->GetAllocationCallbacks())) +{ + memset(m_DefaultPoolContexts, 0, sizeof(m_DefaultPoolContexts)); +} + +VmaDefragmentationContext_T::~VmaDefragmentationContext_T() +{ + for(size_t i = m_CustomPoolContexts.size(); i--; ) + { + VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_CustomPoolContexts[i]; + pBlockVectorCtx->GetBlockVector()->DefragmentationEnd(pBlockVectorCtx, m_Flags, m_pStats); + vma_delete(m_hAllocator, pBlockVectorCtx); + } + for(size_t i = m_hAllocator->m_MemProps.memoryTypeCount; i--; ) + { + VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_DefaultPoolContexts[i]; + if(pBlockVectorCtx) + { + pBlockVectorCtx->GetBlockVector()->DefragmentationEnd(pBlockVectorCtx, m_Flags, m_pStats); + vma_delete(m_hAllocator, pBlockVectorCtx); + } + } +} + +void VmaDefragmentationContext_T::AddPools(uint32_t poolCount, const VmaPool* pPools) +{ + for(uint32_t poolIndex = 0; poolIndex < poolCount; ++poolIndex) + { + VmaPool pool = pPools[poolIndex]; + VMA_ASSERT(pool); + // Pools with algorithm other than default are not defragmented. + if(pool->m_BlockVector.GetAlgorithm() == 0) + { + VmaBlockVectorDefragmentationContext* pBlockVectorDefragCtx = VMA_NULL; + + for(size_t i = m_CustomPoolContexts.size(); i--; ) + { + if(m_CustomPoolContexts[i]->GetCustomPool() == pool) + { + pBlockVectorDefragCtx = m_CustomPoolContexts[i]; + break; + } + } + + if(!pBlockVectorDefragCtx) + { + pBlockVectorDefragCtx = vma_new(m_hAllocator, VmaBlockVectorDefragmentationContext)( + m_hAllocator, + pool, + &pool->m_BlockVector, + m_CurrFrameIndex); + m_CustomPoolContexts.push_back(pBlockVectorDefragCtx); + } + + pBlockVectorDefragCtx->AddAll(); + } + } +} + +void VmaDefragmentationContext_T::AddAllocations( + uint32_t allocationCount, + const VmaAllocation* pAllocations, + VkBool32* pAllocationsChanged) +{ + // Dispatch pAllocations among defragmentators. Create them when necessary. + for(uint32_t allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + const VmaAllocation hAlloc = pAllocations[allocIndex]; + VMA_ASSERT(hAlloc); + // DedicatedAlloc cannot be defragmented. + if((hAlloc->GetType() == VmaAllocation_T::ALLOCATION_TYPE_BLOCK) && + // Lost allocation cannot be defragmented. + (hAlloc->GetLastUseFrameIndex() != VMA_FRAME_INDEX_LOST)) + { + VmaBlockVectorDefragmentationContext* pBlockVectorDefragCtx = VMA_NULL; + + const VmaPool hAllocPool = hAlloc->GetBlock()->GetParentPool(); + // This allocation belongs to custom pool. + if(hAllocPool != VK_NULL_HANDLE) + { + // Pools with algorithm other than default are not defragmented. + if(hAllocPool->m_BlockVector.GetAlgorithm() == 0) + { + for(size_t i = m_CustomPoolContexts.size(); i--; ) + { + if(m_CustomPoolContexts[i]->GetCustomPool() == hAllocPool) + { + pBlockVectorDefragCtx = m_CustomPoolContexts[i]; + break; + } + } + if(!pBlockVectorDefragCtx) + { + pBlockVectorDefragCtx = vma_new(m_hAllocator, VmaBlockVectorDefragmentationContext)( + m_hAllocator, + hAllocPool, + &hAllocPool->m_BlockVector, + m_CurrFrameIndex); + m_CustomPoolContexts.push_back(pBlockVectorDefragCtx); + } + } + } + // This allocation belongs to default pool. + else + { + const uint32_t memTypeIndex = hAlloc->GetMemoryTypeIndex(); + pBlockVectorDefragCtx = m_DefaultPoolContexts[memTypeIndex]; + if(!pBlockVectorDefragCtx) + { + pBlockVectorDefragCtx = vma_new(m_hAllocator, VmaBlockVectorDefragmentationContext)( + m_hAllocator, + VMA_NULL, // hCustomPool + m_hAllocator->m_pBlockVectors[memTypeIndex], + m_CurrFrameIndex); + m_DefaultPoolContexts[memTypeIndex] = pBlockVectorDefragCtx; + } + } + + if(pBlockVectorDefragCtx) + { + VkBool32* const pChanged = (pAllocationsChanged != VMA_NULL) ? + &pAllocationsChanged[allocIndex] : VMA_NULL; + pBlockVectorDefragCtx->AddAllocation(hAlloc, pChanged); + } + } + } +} + +VkResult VmaDefragmentationContext_T::Defragment( + VkDeviceSize maxCpuBytesToMove, uint32_t maxCpuAllocationsToMove, + VkDeviceSize maxGpuBytesToMove, uint32_t maxGpuAllocationsToMove, + VkCommandBuffer commandBuffer, VmaDefragmentationStats* pStats, VmaDefragmentationFlags flags) +{ + if(pStats) + { + memset(pStats, 0, sizeof(VmaDefragmentationStats)); + } + + if(flags & VMA_DEFRAGMENTATION_FLAG_INCREMENTAL) + { + // For incremental defragmetnations, we just earmark how much we can move + // The real meat is in the defragmentation steps + m_MaxCpuBytesToMove = maxCpuBytesToMove; + m_MaxCpuAllocationsToMove = maxCpuAllocationsToMove; + + m_MaxGpuBytesToMove = maxGpuBytesToMove; + m_MaxGpuAllocationsToMove = maxGpuAllocationsToMove; + + if(m_MaxCpuBytesToMove == 0 && m_MaxCpuAllocationsToMove == 0 && + m_MaxGpuBytesToMove == 0 && m_MaxGpuAllocationsToMove == 0) + return VK_SUCCESS; + + return VK_NOT_READY; + } + + if(commandBuffer == VK_NULL_HANDLE) + { + maxGpuBytesToMove = 0; + maxGpuAllocationsToMove = 0; + } + + VkResult res = VK_SUCCESS; + + // Process default pools. + for(uint32_t memTypeIndex = 0; + memTypeIndex < m_hAllocator->GetMemoryTypeCount() && res >= VK_SUCCESS; + ++memTypeIndex) + { + VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_DefaultPoolContexts[memTypeIndex]; + if(pBlockVectorCtx) + { + VMA_ASSERT(pBlockVectorCtx->GetBlockVector()); + pBlockVectorCtx->GetBlockVector()->Defragment( + pBlockVectorCtx, + pStats, flags, + maxCpuBytesToMove, maxCpuAllocationsToMove, + maxGpuBytesToMove, maxGpuAllocationsToMove, + commandBuffer); + if(pBlockVectorCtx->res != VK_SUCCESS) + { + res = pBlockVectorCtx->res; + } + } + } + + // Process custom pools. + for(size_t customCtxIndex = 0, customCtxCount = m_CustomPoolContexts.size(); + customCtxIndex < customCtxCount && res >= VK_SUCCESS; + ++customCtxIndex) + { + VmaBlockVectorDefragmentationContext* pBlockVectorCtx = m_CustomPoolContexts[customCtxIndex]; + VMA_ASSERT(pBlockVectorCtx && pBlockVectorCtx->GetBlockVector()); + pBlockVectorCtx->GetBlockVector()->Defragment( + pBlockVectorCtx, + pStats, flags, + maxCpuBytesToMove, maxCpuAllocationsToMove, + maxGpuBytesToMove, maxGpuAllocationsToMove, + commandBuffer); + if(pBlockVectorCtx->res != VK_SUCCESS) + { + res = pBlockVectorCtx->res; + } + } + + return res; +} + +VkResult VmaDefragmentationContext_T::DefragmentPassBegin(VmaDefragmentationPassInfo* pInfo) +{ + VmaDefragmentationPassMoveInfo* pCurrentMove = pInfo->pMoves; + uint32_t movesLeft = pInfo->moveCount; + + // Process default pools. + for(uint32_t memTypeIndex = 0; + memTypeIndex < m_hAllocator->GetMemoryTypeCount(); + ++memTypeIndex) + { + VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_DefaultPoolContexts[memTypeIndex]; + if(pBlockVectorCtx) + { + VMA_ASSERT(pBlockVectorCtx->GetBlockVector()); + + if(!pBlockVectorCtx->hasDefragmentationPlan) + { + pBlockVectorCtx->GetBlockVector()->Defragment( + pBlockVectorCtx, + m_pStats, m_Flags, + m_MaxCpuBytesToMove, m_MaxCpuAllocationsToMove, + m_MaxGpuBytesToMove, m_MaxGpuAllocationsToMove, + VK_NULL_HANDLE); + + if(pBlockVectorCtx->res < VK_SUCCESS) + continue; + + pBlockVectorCtx->hasDefragmentationPlan = true; + } + + const uint32_t processed = pBlockVectorCtx->GetBlockVector()->ProcessDefragmentations( + pBlockVectorCtx, + pCurrentMove, movesLeft); + + movesLeft -= processed; + pCurrentMove += processed; + } + } + + // Process custom pools. + for(size_t customCtxIndex = 0, customCtxCount = m_CustomPoolContexts.size(); + customCtxIndex < customCtxCount; + ++customCtxIndex) + { + VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_CustomPoolContexts[customCtxIndex]; + VMA_ASSERT(pBlockVectorCtx && pBlockVectorCtx->GetBlockVector()); + + if(!pBlockVectorCtx->hasDefragmentationPlan) + { + pBlockVectorCtx->GetBlockVector()->Defragment( + pBlockVectorCtx, + m_pStats, m_Flags, + m_MaxCpuBytesToMove, m_MaxCpuAllocationsToMove, + m_MaxGpuBytesToMove, m_MaxGpuAllocationsToMove, + VK_NULL_HANDLE); + + if(pBlockVectorCtx->res < VK_SUCCESS) + continue; + + pBlockVectorCtx->hasDefragmentationPlan = true; + } + + const uint32_t processed = pBlockVectorCtx->GetBlockVector()->ProcessDefragmentations( + pBlockVectorCtx, + pCurrentMove, movesLeft); + + movesLeft -= processed; + pCurrentMove += processed; + } + + pInfo->moveCount = pInfo->moveCount - movesLeft; + + return VK_SUCCESS; +} +VkResult VmaDefragmentationContext_T::DefragmentPassEnd() +{ + VkResult res = VK_SUCCESS; + + // Process default pools. + for(uint32_t memTypeIndex = 0; + memTypeIndex < m_hAllocator->GetMemoryTypeCount(); + ++memTypeIndex) + { + VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_DefaultPoolContexts[memTypeIndex]; + if(pBlockVectorCtx) + { + VMA_ASSERT(pBlockVectorCtx->GetBlockVector()); + + if(!pBlockVectorCtx->hasDefragmentationPlan) + { + res = VK_NOT_READY; + continue; + } + + pBlockVectorCtx->GetBlockVector()->CommitDefragmentations( + pBlockVectorCtx, m_pStats); + + if(pBlockVectorCtx->defragmentationMoves.size() != pBlockVectorCtx->defragmentationMovesCommitted) + res = VK_NOT_READY; + } + } + + // Process custom pools. + for(size_t customCtxIndex = 0, customCtxCount = m_CustomPoolContexts.size(); + customCtxIndex < customCtxCount; + ++customCtxIndex) + { + VmaBlockVectorDefragmentationContext *pBlockVectorCtx = m_CustomPoolContexts[customCtxIndex]; + VMA_ASSERT(pBlockVectorCtx && pBlockVectorCtx->GetBlockVector()); + + if(!pBlockVectorCtx->hasDefragmentationPlan) + { + res = VK_NOT_READY; + continue; + } + + pBlockVectorCtx->GetBlockVector()->CommitDefragmentations( + pBlockVectorCtx, m_pStats); + + if(pBlockVectorCtx->defragmentationMoves.size() != pBlockVectorCtx->defragmentationMovesCommitted) + res = VK_NOT_READY; + } + + return res; +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaRecorder + +#if VMA_RECORDING_ENABLED + +VmaRecorder::VmaRecorder() : + m_UseMutex(true), + m_Flags(0), + m_File(VMA_NULL), + m_RecordingStartTime(std::chrono::high_resolution_clock::now()) +{ +} + +VkResult VmaRecorder::Init(const VmaRecordSettings& settings, bool useMutex) +{ + m_UseMutex = useMutex; + m_Flags = settings.flags; + +#if defined(_WIN32) + // Open file for writing. + errno_t err = fopen_s(&m_File, settings.pFilePath, "wb"); + + if(err != 0) + { + return VK_ERROR_INITIALIZATION_FAILED; + } +#else + // Open file for writing. + m_File = fopen(settings.pFilePath, "wb"); + + if(m_File == 0) + { + return VK_ERROR_INITIALIZATION_FAILED; + } +#endif + + // Write header. + fprintf(m_File, "%s\n", "Vulkan Memory Allocator,Calls recording"); + fprintf(m_File, "%s\n", "1,8"); + + return VK_SUCCESS; +} + +VmaRecorder::~VmaRecorder() +{ + if(m_File != VMA_NULL) + { + fclose(m_File); + } +} + +void VmaRecorder::RecordCreateAllocator(uint32_t frameIndex) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaCreateAllocator\n", callParams.threadId, callParams.time, frameIndex); + Flush(); +} + +void VmaRecorder::RecordDestroyAllocator(uint32_t frameIndex) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaDestroyAllocator\n", callParams.threadId, callParams.time, frameIndex); + Flush(); +} + +void VmaRecorder::RecordCreatePool(uint32_t frameIndex, const VmaPoolCreateInfo& createInfo, VmaPool pool) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaCreatePool,%u,%u,%llu,%llu,%llu,%u,%p\n", callParams.threadId, callParams.time, frameIndex, + createInfo.memoryTypeIndex, + createInfo.flags, + createInfo.blockSize, + (uint64_t)createInfo.minBlockCount, + (uint64_t)createInfo.maxBlockCount, + createInfo.frameInUseCount, + pool); + Flush(); +} + +void VmaRecorder::RecordDestroyPool(uint32_t frameIndex, VmaPool pool) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaDestroyPool,%p\n", callParams.threadId, callParams.time, frameIndex, + pool); + Flush(); +} + +void VmaRecorder::RecordAllocateMemory(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + const VmaAllocationCreateInfo& createInfo, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr(createInfo.flags, createInfo.pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemory,%llu,%llu,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + vkMemReq.size, + vkMemReq.alignment, + vkMemReq.memoryTypeBits, + createInfo.flags, + createInfo.usage, + createInfo.requiredFlags, + createInfo.preferredFlags, + createInfo.memoryTypeBits, + createInfo.pool, + allocation, + userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordAllocateMemoryPages(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + const VmaAllocationCreateInfo& createInfo, + uint64_t allocationCount, + const VmaAllocation* pAllocations) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr(createInfo.flags, createInfo.pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemoryPages,%llu,%llu,%u,%u,%u,%u,%u,%u,%p,", callParams.threadId, callParams.time, frameIndex, + vkMemReq.size, + vkMemReq.alignment, + vkMemReq.memoryTypeBits, + createInfo.flags, + createInfo.usage, + createInfo.requiredFlags, + createInfo.preferredFlags, + createInfo.memoryTypeBits, + createInfo.pool); + PrintPointerList(allocationCount, pAllocations); + fprintf(m_File, ",%s\n", userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordAllocateMemoryForBuffer(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + const VmaAllocationCreateInfo& createInfo, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr(createInfo.flags, createInfo.pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemoryForBuffer,%llu,%llu,%u,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + vkMemReq.size, + vkMemReq.alignment, + vkMemReq.memoryTypeBits, + requiresDedicatedAllocation ? 1 : 0, + prefersDedicatedAllocation ? 1 : 0, + createInfo.flags, + createInfo.usage, + createInfo.requiredFlags, + createInfo.preferredFlags, + createInfo.memoryTypeBits, + createInfo.pool, + allocation, + userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordAllocateMemoryForImage(uint32_t frameIndex, + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + const VmaAllocationCreateInfo& createInfo, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr(createInfo.flags, createInfo.pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaAllocateMemoryForImage,%llu,%llu,%u,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + vkMemReq.size, + vkMemReq.alignment, + vkMemReq.memoryTypeBits, + requiresDedicatedAllocation ? 1 : 0, + prefersDedicatedAllocation ? 1 : 0, + createInfo.flags, + createInfo.usage, + createInfo.requiredFlags, + createInfo.preferredFlags, + createInfo.memoryTypeBits, + createInfo.pool, + allocation, + userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordFreeMemory(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaFreeMemory,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordFreeMemoryPages(uint32_t frameIndex, + uint64_t allocationCount, + const VmaAllocation* pAllocations) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaFreeMemoryPages,", callParams.threadId, callParams.time, frameIndex); + PrintPointerList(allocationCount, pAllocations); + fprintf(m_File, "\n"); + Flush(); +} + +void VmaRecorder::RecordSetAllocationUserData(uint32_t frameIndex, + VmaAllocation allocation, + const void* pUserData) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr( + allocation->IsUserDataString() ? VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT : 0, + pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaSetAllocationUserData,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + allocation, + userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordCreateLostAllocation(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaCreateLostAllocation,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordMapMemory(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaMapMemory,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordUnmapMemory(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaUnmapMemory,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordFlushAllocation(uint32_t frameIndex, + VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaFlushAllocation,%p,%llu,%llu\n", callParams.threadId, callParams.time, frameIndex, + allocation, + offset, + size); + Flush(); +} + +void VmaRecorder::RecordInvalidateAllocation(uint32_t frameIndex, + VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaInvalidateAllocation,%p,%llu,%llu\n", callParams.threadId, callParams.time, frameIndex, + allocation, + offset, + size); + Flush(); +} + +void VmaRecorder::RecordCreateBuffer(uint32_t frameIndex, + const VkBufferCreateInfo& bufCreateInfo, + const VmaAllocationCreateInfo& allocCreateInfo, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr(allocCreateInfo.flags, allocCreateInfo.pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaCreateBuffer,%u,%llu,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + bufCreateInfo.flags, + bufCreateInfo.size, + bufCreateInfo.usage, + bufCreateInfo.sharingMode, + allocCreateInfo.flags, + allocCreateInfo.usage, + allocCreateInfo.requiredFlags, + allocCreateInfo.preferredFlags, + allocCreateInfo.memoryTypeBits, + allocCreateInfo.pool, + allocation, + userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordCreateImage(uint32_t frameIndex, + const VkImageCreateInfo& imageCreateInfo, + const VmaAllocationCreateInfo& allocCreateInfo, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + UserDataString userDataStr(allocCreateInfo.flags, allocCreateInfo.pUserData); + fprintf(m_File, "%u,%.3f,%u,vmaCreateImage,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%p,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + imageCreateInfo.flags, + imageCreateInfo.imageType, + imageCreateInfo.format, + imageCreateInfo.extent.width, + imageCreateInfo.extent.height, + imageCreateInfo.extent.depth, + imageCreateInfo.mipLevels, + imageCreateInfo.arrayLayers, + imageCreateInfo.samples, + imageCreateInfo.tiling, + imageCreateInfo.usage, + imageCreateInfo.sharingMode, + imageCreateInfo.initialLayout, + allocCreateInfo.flags, + allocCreateInfo.usage, + allocCreateInfo.requiredFlags, + allocCreateInfo.preferredFlags, + allocCreateInfo.memoryTypeBits, + allocCreateInfo.pool, + allocation, + userDataStr.GetString()); + Flush(); +} + +void VmaRecorder::RecordDestroyBuffer(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaDestroyBuffer,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordDestroyImage(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaDestroyImage,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordTouchAllocation(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaTouchAllocation,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordGetAllocationInfo(uint32_t frameIndex, + VmaAllocation allocation) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaGetAllocationInfo,%p\n", callParams.threadId, callParams.time, frameIndex, + allocation); + Flush(); +} + +void VmaRecorder::RecordMakePoolAllocationsLost(uint32_t frameIndex, + VmaPool pool) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaMakePoolAllocationsLost,%p\n", callParams.threadId, callParams.time, frameIndex, + pool); + Flush(); +} + +void VmaRecorder::RecordDefragmentationBegin(uint32_t frameIndex, + const VmaDefragmentationInfo2& info, + VmaDefragmentationContext ctx) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaDefragmentationBegin,%u,", callParams.threadId, callParams.time, frameIndex, + info.flags); + PrintPointerList(info.allocationCount, info.pAllocations); + fprintf(m_File, ","); + PrintPointerList(info.poolCount, info.pPools); + fprintf(m_File, ",%llu,%u,%llu,%u,%p,%p\n", + info.maxCpuBytesToMove, + info.maxCpuAllocationsToMove, + info.maxGpuBytesToMove, + info.maxGpuAllocationsToMove, + info.commandBuffer, + ctx); + Flush(); +} + +void VmaRecorder::RecordDefragmentationEnd(uint32_t frameIndex, + VmaDefragmentationContext ctx) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaDefragmentationEnd,%p\n", callParams.threadId, callParams.time, frameIndex, + ctx); + Flush(); +} + +void VmaRecorder::RecordSetPoolName(uint32_t frameIndex, + VmaPool pool, + const char* name) +{ + CallParams callParams; + GetBasicParams(callParams); + + VmaMutexLock lock(m_FileMutex, m_UseMutex); + fprintf(m_File, "%u,%.3f,%u,vmaSetPoolName,%p,%s\n", callParams.threadId, callParams.time, frameIndex, + pool, name != VMA_NULL ? name : ""); + Flush(); +} + +VmaRecorder::UserDataString::UserDataString(VmaAllocationCreateFlags allocFlags, const void* pUserData) +{ + if(pUserData != VMA_NULL) + { + if((allocFlags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0) + { + m_Str = (const char*)pUserData; + } + else + { + // If VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT is not specified, convert the string's memory address to a string and store it. + snprintf(m_PtrStr, 17, "%p", pUserData); + m_Str = m_PtrStr; + } + } + else + { + m_Str = ""; + } +} + +void VmaRecorder::WriteConfiguration( + const VkPhysicalDeviceProperties& devProps, + const VkPhysicalDeviceMemoryProperties& memProps, + uint32_t vulkanApiVersion, + bool dedicatedAllocationExtensionEnabled, + bool bindMemory2ExtensionEnabled, + bool memoryBudgetExtensionEnabled, + bool deviceCoherentMemoryExtensionEnabled) +{ + fprintf(m_File, "Config,Begin\n"); + + fprintf(m_File, "VulkanApiVersion,%u,%u\n", VK_VERSION_MAJOR(vulkanApiVersion), VK_VERSION_MINOR(vulkanApiVersion)); + + fprintf(m_File, "PhysicalDevice,apiVersion,%u\n", devProps.apiVersion); + fprintf(m_File, "PhysicalDevice,driverVersion,%u\n", devProps.driverVersion); + fprintf(m_File, "PhysicalDevice,vendorID,%u\n", devProps.vendorID); + fprintf(m_File, "PhysicalDevice,deviceID,%u\n", devProps.deviceID); + fprintf(m_File, "PhysicalDevice,deviceType,%u\n", devProps.deviceType); + fprintf(m_File, "PhysicalDevice,deviceName,%s\n", devProps.deviceName); + + fprintf(m_File, "PhysicalDeviceLimits,maxMemoryAllocationCount,%u\n", devProps.limits.maxMemoryAllocationCount); + fprintf(m_File, "PhysicalDeviceLimits,bufferImageGranularity,%llu\n", devProps.limits.bufferImageGranularity); + fprintf(m_File, "PhysicalDeviceLimits,nonCoherentAtomSize,%llu\n", devProps.limits.nonCoherentAtomSize); + + fprintf(m_File, "PhysicalDeviceMemory,HeapCount,%u\n", memProps.memoryHeapCount); + for(uint32_t i = 0; i < memProps.memoryHeapCount; ++i) + { + fprintf(m_File, "PhysicalDeviceMemory,Heap,%u,size,%llu\n", i, memProps.memoryHeaps[i].size); + fprintf(m_File, "PhysicalDeviceMemory,Heap,%u,flags,%u\n", i, memProps.memoryHeaps[i].flags); + } + fprintf(m_File, "PhysicalDeviceMemory,TypeCount,%u\n", memProps.memoryTypeCount); + for(uint32_t i = 0; i < memProps.memoryTypeCount; ++i) + { + fprintf(m_File, "PhysicalDeviceMemory,Type,%u,heapIndex,%u\n", i, memProps.memoryTypes[i].heapIndex); + fprintf(m_File, "PhysicalDeviceMemory,Type,%u,propertyFlags,%u\n", i, memProps.memoryTypes[i].propertyFlags); + } + + fprintf(m_File, "Extension,VK_KHR_dedicated_allocation,%u\n", dedicatedAllocationExtensionEnabled ? 1 : 0); + fprintf(m_File, "Extension,VK_KHR_bind_memory2,%u\n", bindMemory2ExtensionEnabled ? 1 : 0); + fprintf(m_File, "Extension,VK_EXT_memory_budget,%u\n", memoryBudgetExtensionEnabled ? 1 : 0); + fprintf(m_File, "Extension,VK_AMD_device_coherent_memory,%u\n", deviceCoherentMemoryExtensionEnabled ? 1 : 0); + + fprintf(m_File, "Macro,VMA_DEBUG_ALWAYS_DEDICATED_MEMORY,%u\n", VMA_DEBUG_ALWAYS_DEDICATED_MEMORY ? 1 : 0); + fprintf(m_File, "Macro,VMA_DEBUG_ALIGNMENT,%llu\n", (VkDeviceSize)VMA_DEBUG_ALIGNMENT); + fprintf(m_File, "Macro,VMA_DEBUG_MARGIN,%llu\n", (VkDeviceSize)VMA_DEBUG_MARGIN); + fprintf(m_File, "Macro,VMA_DEBUG_INITIALIZE_ALLOCATIONS,%u\n", VMA_DEBUG_INITIALIZE_ALLOCATIONS ? 1 : 0); + fprintf(m_File, "Macro,VMA_DEBUG_DETECT_CORRUPTION,%u\n", VMA_DEBUG_DETECT_CORRUPTION ? 1 : 0); + fprintf(m_File, "Macro,VMA_DEBUG_GLOBAL_MUTEX,%u\n", VMA_DEBUG_GLOBAL_MUTEX ? 1 : 0); + fprintf(m_File, "Macro,VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY,%llu\n", (VkDeviceSize)VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY); + fprintf(m_File, "Macro,VMA_SMALL_HEAP_MAX_SIZE,%llu\n", (VkDeviceSize)VMA_SMALL_HEAP_MAX_SIZE); + fprintf(m_File, "Macro,VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE,%llu\n", (VkDeviceSize)VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE); + + fprintf(m_File, "Config,End\n"); +} + +void VmaRecorder::GetBasicParams(CallParams& outParams) +{ + #if defined(_WIN32) + outParams.threadId = GetCurrentThreadId(); + #else + // Use C++11 features to get thread id and convert it to uint32_t. + // There is room for optimization since sstream is quite slow. + // Is there a better way to convert std::this_thread::get_id() to uint32_t? + std::thread::id thread_id = std::this_thread::get_id(); + stringstream thread_id_to_string_converter; + thread_id_to_string_converter << thread_id; + string thread_id_as_string = thread_id_to_string_converter.str(); + outParams.threadId = static_cast(std::stoi(thread_id_as_string.c_str())); + #endif + + auto current_time = std::chrono::high_resolution_clock::now(); + + outParams.time = std::chrono::duration(current_time - m_RecordingStartTime).count(); +} + +void VmaRecorder::PrintPointerList(uint64_t count, const VmaAllocation* pItems) +{ + if(count) + { + fprintf(m_File, "%p", pItems[0]); + for(uint64_t i = 1; i < count; ++i) + { + fprintf(m_File, " %p", pItems[i]); + } + } +} + +void VmaRecorder::Flush() +{ + if((m_Flags & VMA_RECORD_FLUSH_AFTER_CALL_BIT) != 0) + { + fflush(m_File); + } +} + +#endif // #if VMA_RECORDING_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +// VmaAllocationObjectAllocator + +VmaAllocationObjectAllocator::VmaAllocationObjectAllocator(const VkAllocationCallbacks* pAllocationCallbacks) : + m_Allocator(pAllocationCallbacks, 1024) +{ +} + +template VmaAllocation VmaAllocationObjectAllocator::Allocate(Types... args) +{ + VmaMutexLock mutexLock(m_Mutex); + return m_Allocator.Alloc(std::forward(args)...); +} + +void VmaAllocationObjectAllocator::Free(VmaAllocation hAlloc) +{ + VmaMutexLock mutexLock(m_Mutex); + m_Allocator.Free(hAlloc); +} + +//////////////////////////////////////////////////////////////////////////////// +// VmaAllocator_T + +VmaAllocator_T::VmaAllocator_T(const VmaAllocatorCreateInfo* pCreateInfo) : + m_UseMutex((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXTERNALLY_SYNCHRONIZED_BIT) == 0), + m_VulkanApiVersion(pCreateInfo->vulkanApiVersion != 0 ? pCreateInfo->vulkanApiVersion : VK_API_VERSION_1_0), + m_UseKhrDedicatedAllocation((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT) != 0), + m_UseKhrBindMemory2((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT) != 0), + m_UseExtMemoryBudget((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT) != 0), + m_UseAmdDeviceCoherentMemory((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_AMD_DEVICE_COHERENT_MEMORY_BIT) != 0), + m_UseKhrBufferDeviceAddress((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT) != 0), + m_hDevice(pCreateInfo->device), + m_hInstance(pCreateInfo->instance), + m_AllocationCallbacksSpecified(pCreateInfo->pAllocationCallbacks != VMA_NULL), + m_AllocationCallbacks(pCreateInfo->pAllocationCallbacks ? + *pCreateInfo->pAllocationCallbacks : VmaEmptyAllocationCallbacks), + m_AllocationObjectAllocator(&m_AllocationCallbacks), + m_HeapSizeLimitMask(0), + m_PreferredLargeHeapBlockSize(0), + m_PhysicalDevice(pCreateInfo->physicalDevice), + m_CurrentFrameIndex(0), + m_GpuDefragmentationMemoryTypeBits(UINT32_MAX), + m_Pools(VmaStlAllocator(GetAllocationCallbacks())), + m_NextPoolId(0), + m_GlobalMemoryTypeBits(UINT32_MAX) +#if VMA_RECORDING_ENABLED + ,m_pRecorder(VMA_NULL) +#endif +{ + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + m_UseKhrDedicatedAllocation = false; + m_UseKhrBindMemory2 = false; + } + + if(VMA_DEBUG_DETECT_CORRUPTION) + { + // Needs to be multiply of uint32_t size because we are going to write VMA_CORRUPTION_DETECTION_MAGIC_VALUE to it. + VMA_ASSERT(VMA_DEBUG_MARGIN % sizeof(uint32_t) == 0); + } + + VMA_ASSERT(pCreateInfo->physicalDevice && pCreateInfo->device && pCreateInfo->instance); + + if(m_VulkanApiVersion < VK_MAKE_VERSION(1, 1, 0)) + { +#if !(VMA_DEDICATED_ALLOCATION) + if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT) != 0) + { + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_KHR_DEDICATED_ALLOCATION_BIT set but required extensions are disabled by preprocessor macros."); + } +#endif +#if !(VMA_BIND_MEMORY2) + if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT) != 0) + { + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_KHR_BIND_MEMORY2_BIT set but required extension is disabled by preprocessor macros."); + } +#endif + } +#if !(VMA_MEMORY_BUDGET) + if((pCreateInfo->flags & VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT) != 0) + { + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_EXT_MEMORY_BUDGET_BIT set but required extension is disabled by preprocessor macros."); + } +#endif +#if !(VMA_BUFFER_DEVICE_ADDRESS) + if(m_UseKhrBufferDeviceAddress) + { + VMA_ASSERT(0 && "VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT is set but required extension or Vulkan 1.2 is not available in your Vulkan header or its support in VMA has been disabled by a preprocessor macro."); + } +#endif +#if VMA_VULKAN_VERSION < 1002000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 2, 0)) + { + VMA_ASSERT(0 && "vulkanApiVersion >= VK_API_VERSION_1_2 but required Vulkan version is disabled by preprocessor macros."); + } +#endif +#if VMA_VULKAN_VERSION < 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VMA_ASSERT(0 && "vulkanApiVersion >= VK_API_VERSION_1_1 but required Vulkan version is disabled by preprocessor macros."); + } +#endif + + memset(&m_DeviceMemoryCallbacks, 0 ,sizeof(m_DeviceMemoryCallbacks)); + memset(&m_PhysicalDeviceProperties, 0, sizeof(m_PhysicalDeviceProperties)); + memset(&m_MemProps, 0, sizeof(m_MemProps)); + + memset(&m_pBlockVectors, 0, sizeof(m_pBlockVectors)); + memset(&m_pDedicatedAllocations, 0, sizeof(m_pDedicatedAllocations)); + memset(&m_VulkanFunctions, 0, sizeof(m_VulkanFunctions)); + + if(pCreateInfo->pDeviceMemoryCallbacks != VMA_NULL) + { + m_DeviceMemoryCallbacks.pUserData = pCreateInfo->pDeviceMemoryCallbacks->pUserData; + m_DeviceMemoryCallbacks.pfnAllocate = pCreateInfo->pDeviceMemoryCallbacks->pfnAllocate; + m_DeviceMemoryCallbacks.pfnFree = pCreateInfo->pDeviceMemoryCallbacks->pfnFree; + } + + ImportVulkanFunctions(pCreateInfo->pVulkanFunctions); + + (*m_VulkanFunctions.vkGetPhysicalDeviceProperties)(m_PhysicalDevice, &m_PhysicalDeviceProperties); + (*m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties)(m_PhysicalDevice, &m_MemProps); + + VMA_ASSERT(VmaIsPow2(VMA_DEBUG_ALIGNMENT)); + VMA_ASSERT(VmaIsPow2(VMA_DEBUG_MIN_BUFFER_IMAGE_GRANULARITY)); + VMA_ASSERT(VmaIsPow2(m_PhysicalDeviceProperties.limits.bufferImageGranularity)); + VMA_ASSERT(VmaIsPow2(m_PhysicalDeviceProperties.limits.nonCoherentAtomSize)); + + m_PreferredLargeHeapBlockSize = (pCreateInfo->preferredLargeHeapBlockSize != 0) ? + pCreateInfo->preferredLargeHeapBlockSize : static_cast(VMA_DEFAULT_LARGE_HEAP_BLOCK_SIZE); + + m_GlobalMemoryTypeBits = CalculateGlobalMemoryTypeBits(); + + if(pCreateInfo->pHeapSizeLimit != VMA_NULL) + { + for(uint32_t heapIndex = 0; heapIndex < GetMemoryHeapCount(); ++heapIndex) + { + const VkDeviceSize limit = pCreateInfo->pHeapSizeLimit[heapIndex]; + if(limit != VK_WHOLE_SIZE) + { + m_HeapSizeLimitMask |= 1u << heapIndex; + if(limit < m_MemProps.memoryHeaps[heapIndex].size) + { + m_MemProps.memoryHeaps[heapIndex].size = limit; + } + } + } + } + + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + const VkDeviceSize preferredBlockSize = CalcPreferredBlockSize(memTypeIndex); + + m_pBlockVectors[memTypeIndex] = vma_new(this, VmaBlockVector)( + this, + VK_NULL_HANDLE, // hParentPool + memTypeIndex, + preferredBlockSize, + 0, + SIZE_MAX, + GetBufferImageGranularity(), + pCreateInfo->frameInUseCount, + false, // explicitBlockSize + false); // linearAlgorithm + // No need to call m_pBlockVectors[memTypeIndex][blockVectorTypeIndex]->CreateMinBlocks here, + // becase minBlockCount is 0. + m_pDedicatedAllocations[memTypeIndex] = vma_new(this, AllocationVectorType)(VmaStlAllocator(GetAllocationCallbacks())); + + } +} + +VkResult VmaAllocator_T::Init(const VmaAllocatorCreateInfo* pCreateInfo) +{ + VkResult res = VK_SUCCESS; + + if(pCreateInfo->pRecordSettings != VMA_NULL && + !VmaStrIsEmpty(pCreateInfo->pRecordSettings->pFilePath)) + { +#if VMA_RECORDING_ENABLED + m_pRecorder = vma_new(this, VmaRecorder)(); + res = m_pRecorder->Init(*pCreateInfo->pRecordSettings, m_UseMutex); + if(res != VK_SUCCESS) + { + return res; + } + m_pRecorder->WriteConfiguration( + m_PhysicalDeviceProperties, + m_MemProps, + m_VulkanApiVersion, + m_UseKhrDedicatedAllocation, + m_UseKhrBindMemory2, + m_UseExtMemoryBudget, + m_UseAmdDeviceCoherentMemory); + m_pRecorder->RecordCreateAllocator(GetCurrentFrameIndex()); +#else + VMA_ASSERT(0 && "VmaAllocatorCreateInfo::pRecordSettings used, but not supported due to VMA_RECORDING_ENABLED not defined to 1."); + return VK_ERROR_FEATURE_NOT_PRESENT; +#endif + } + +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) + { + UpdateVulkanBudget(); + } +#endif // #if VMA_MEMORY_BUDGET + + return res; +} + +VmaAllocator_T::~VmaAllocator_T() +{ +#if VMA_RECORDING_ENABLED + if(m_pRecorder != VMA_NULL) + { + m_pRecorder->RecordDestroyAllocator(GetCurrentFrameIndex()); + vma_delete(this, m_pRecorder); + } +#endif + + VMA_ASSERT(m_Pools.empty()); + + for(size_t i = GetMemoryTypeCount(); i--; ) + { + if(m_pDedicatedAllocations[i] != VMA_NULL && !m_pDedicatedAllocations[i]->empty()) + { + VMA_ASSERT(0 && "Unfreed dedicated allocations found."); + } + + vma_delete(this, m_pDedicatedAllocations[i]); + vma_delete(this, m_pBlockVectors[i]); + } +} + +void VmaAllocator_T::ImportVulkanFunctions(const VmaVulkanFunctions* pVulkanFunctions) +{ +#if VMA_STATIC_VULKAN_FUNCTIONS == 1 + ImportVulkanFunctions_Static(); +#endif + + if(pVulkanFunctions != VMA_NULL) + { + ImportVulkanFunctions_Custom(pVulkanFunctions); + } + +#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + ImportVulkanFunctions_Dynamic(); +#endif + + ValidateVulkanFunctions(); +} + +#if VMA_STATIC_VULKAN_FUNCTIONS == 1 + +void VmaAllocator_T::ImportVulkanFunctions_Static() +{ + // Vulkan 1.0 + m_VulkanFunctions.vkGetPhysicalDeviceProperties = (PFN_vkGetPhysicalDeviceProperties)vkGetPhysicalDeviceProperties; + m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties = (PFN_vkGetPhysicalDeviceMemoryProperties)vkGetPhysicalDeviceMemoryProperties; + m_VulkanFunctions.vkAllocateMemory = (PFN_vkAllocateMemory)vkAllocateMemory; + m_VulkanFunctions.vkFreeMemory = (PFN_vkFreeMemory)vkFreeMemory; + m_VulkanFunctions.vkMapMemory = (PFN_vkMapMemory)vkMapMemory; + m_VulkanFunctions.vkUnmapMemory = (PFN_vkUnmapMemory)vkUnmapMemory; + m_VulkanFunctions.vkFlushMappedMemoryRanges = (PFN_vkFlushMappedMemoryRanges)vkFlushMappedMemoryRanges; + m_VulkanFunctions.vkInvalidateMappedMemoryRanges = (PFN_vkInvalidateMappedMemoryRanges)vkInvalidateMappedMemoryRanges; + m_VulkanFunctions.vkBindBufferMemory = (PFN_vkBindBufferMemory)vkBindBufferMemory; + m_VulkanFunctions.vkBindImageMemory = (PFN_vkBindImageMemory)vkBindImageMemory; + m_VulkanFunctions.vkGetBufferMemoryRequirements = (PFN_vkGetBufferMemoryRequirements)vkGetBufferMemoryRequirements; + m_VulkanFunctions.vkGetImageMemoryRequirements = (PFN_vkGetImageMemoryRequirements)vkGetImageMemoryRequirements; + m_VulkanFunctions.vkCreateBuffer = (PFN_vkCreateBuffer)vkCreateBuffer; + m_VulkanFunctions.vkDestroyBuffer = (PFN_vkDestroyBuffer)vkDestroyBuffer; + m_VulkanFunctions.vkCreateImage = (PFN_vkCreateImage)vkCreateImage; + m_VulkanFunctions.vkDestroyImage = (PFN_vkDestroyImage)vkDestroyImage; + m_VulkanFunctions.vkCmdCopyBuffer = (PFN_vkCmdCopyBuffer)vkCmdCopyBuffer; + + // Vulkan 1.1 +#if VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2)vkGetBufferMemoryRequirements2; + m_VulkanFunctions.vkGetImageMemoryRequirements2KHR = (PFN_vkGetImageMemoryRequirements2)vkGetImageMemoryRequirements2; + m_VulkanFunctions.vkBindBufferMemory2KHR = (PFN_vkBindBufferMemory2)vkBindBufferMemory2; + m_VulkanFunctions.vkBindImageMemory2KHR = (PFN_vkBindImageMemory2)vkBindImageMemory2; + m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties2KHR = (PFN_vkGetPhysicalDeviceMemoryProperties2)vkGetPhysicalDeviceMemoryProperties2; + } +#endif +} + +#endif // #if VMA_STATIC_VULKAN_FUNCTIONS == 1 + +void VmaAllocator_T::ImportVulkanFunctions_Custom(const VmaVulkanFunctions* pVulkanFunctions) +{ + VMA_ASSERT(pVulkanFunctions != VMA_NULL); + +#define VMA_COPY_IF_NOT_NULL(funcName) \ + if(pVulkanFunctions->funcName != VMA_NULL) m_VulkanFunctions.funcName = pVulkanFunctions->funcName; + + VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceProperties); + VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceMemoryProperties); + VMA_COPY_IF_NOT_NULL(vkAllocateMemory); + VMA_COPY_IF_NOT_NULL(vkFreeMemory); + VMA_COPY_IF_NOT_NULL(vkMapMemory); + VMA_COPY_IF_NOT_NULL(vkUnmapMemory); + VMA_COPY_IF_NOT_NULL(vkFlushMappedMemoryRanges); + VMA_COPY_IF_NOT_NULL(vkInvalidateMappedMemoryRanges); + VMA_COPY_IF_NOT_NULL(vkBindBufferMemory); + VMA_COPY_IF_NOT_NULL(vkBindImageMemory); + VMA_COPY_IF_NOT_NULL(vkGetBufferMemoryRequirements); + VMA_COPY_IF_NOT_NULL(vkGetImageMemoryRequirements); + VMA_COPY_IF_NOT_NULL(vkCreateBuffer); + VMA_COPY_IF_NOT_NULL(vkDestroyBuffer); + VMA_COPY_IF_NOT_NULL(vkCreateImage); + VMA_COPY_IF_NOT_NULL(vkDestroyImage); + VMA_COPY_IF_NOT_NULL(vkCmdCopyBuffer); + +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + VMA_COPY_IF_NOT_NULL(vkGetBufferMemoryRequirements2KHR); + VMA_COPY_IF_NOT_NULL(vkGetImageMemoryRequirements2KHR); +#endif + +#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 + VMA_COPY_IF_NOT_NULL(vkBindBufferMemory2KHR); + VMA_COPY_IF_NOT_NULL(vkBindImageMemory2KHR); +#endif + +#if VMA_MEMORY_BUDGET + VMA_COPY_IF_NOT_NULL(vkGetPhysicalDeviceMemoryProperties2KHR); +#endif + +#undef VMA_COPY_IF_NOT_NULL +} + +#if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + +void VmaAllocator_T::ImportVulkanFunctions_Dynamic() +{ +#define VMA_FETCH_INSTANCE_FUNC(memberName, functionPointerType, functionNameString) \ + if(m_VulkanFunctions.memberName == VMA_NULL) \ + m_VulkanFunctions.memberName = \ + (functionPointerType)vkGetInstanceProcAddr(m_hInstance, functionNameString); +#define VMA_FETCH_DEVICE_FUNC(memberName, functionPointerType, functionNameString) \ + if(m_VulkanFunctions.memberName == VMA_NULL) \ + m_VulkanFunctions.memberName = \ + (functionPointerType)vkGetDeviceProcAddr(m_hDevice, functionNameString); + + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceProperties, PFN_vkGetPhysicalDeviceProperties, "vkGetPhysicalDeviceProperties"); + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties, PFN_vkGetPhysicalDeviceMemoryProperties, "vkGetPhysicalDeviceMemoryProperties"); + VMA_FETCH_DEVICE_FUNC(vkAllocateMemory, PFN_vkAllocateMemory, "vkAllocateMemory"); + VMA_FETCH_DEVICE_FUNC(vkFreeMemory, PFN_vkFreeMemory, "vkFreeMemory"); + VMA_FETCH_DEVICE_FUNC(vkMapMemory, PFN_vkMapMemory, "vkMapMemory"); + VMA_FETCH_DEVICE_FUNC(vkUnmapMemory, PFN_vkUnmapMemory, "vkUnmapMemory"); + VMA_FETCH_DEVICE_FUNC(vkFlushMappedMemoryRanges, PFN_vkFlushMappedMemoryRanges, "vkFlushMappedMemoryRanges"); + VMA_FETCH_DEVICE_FUNC(vkInvalidateMappedMemoryRanges, PFN_vkInvalidateMappedMemoryRanges, "vkInvalidateMappedMemoryRanges"); + VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory, PFN_vkBindBufferMemory, "vkBindBufferMemory"); + VMA_FETCH_DEVICE_FUNC(vkBindImageMemory, PFN_vkBindImageMemory, "vkBindImageMemory"); + VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements, PFN_vkGetBufferMemoryRequirements, "vkGetBufferMemoryRequirements"); + VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements, PFN_vkGetImageMemoryRequirements, "vkGetImageMemoryRequirements"); + VMA_FETCH_DEVICE_FUNC(vkCreateBuffer, PFN_vkCreateBuffer, "vkCreateBuffer"); + VMA_FETCH_DEVICE_FUNC(vkDestroyBuffer, PFN_vkDestroyBuffer, "vkDestroyBuffer"); + VMA_FETCH_DEVICE_FUNC(vkCreateImage, PFN_vkCreateImage, "vkCreateImage"); + VMA_FETCH_DEVICE_FUNC(vkDestroyImage, PFN_vkDestroyImage, "vkDestroyImage"); + VMA_FETCH_DEVICE_FUNC(vkCmdCopyBuffer, PFN_vkCmdCopyBuffer, "vkCmdCopyBuffer"); + +#if VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements2KHR, PFN_vkGetBufferMemoryRequirements2, "vkGetBufferMemoryRequirements2"); + VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements2KHR, PFN_vkGetImageMemoryRequirements2, "vkGetImageMemoryRequirements2"); + VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory2KHR, PFN_vkBindBufferMemory2, "vkBindBufferMemory2"); + VMA_FETCH_DEVICE_FUNC(vkBindImageMemory2KHR, PFN_vkBindImageMemory2, "vkBindImageMemory2"); + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties2KHR, PFN_vkGetPhysicalDeviceMemoryProperties2, "vkGetPhysicalDeviceMemoryProperties2"); + } +#endif + +#if VMA_DEDICATED_ALLOCATION + if(m_UseKhrDedicatedAllocation) + { + VMA_FETCH_DEVICE_FUNC(vkGetBufferMemoryRequirements2KHR, PFN_vkGetBufferMemoryRequirements2KHR, "vkGetBufferMemoryRequirements2KHR"); + VMA_FETCH_DEVICE_FUNC(vkGetImageMemoryRequirements2KHR, PFN_vkGetImageMemoryRequirements2KHR, "vkGetImageMemoryRequirements2KHR"); + } +#endif + +#if VMA_BIND_MEMORY2 + if(m_UseKhrBindMemory2) + { + VMA_FETCH_DEVICE_FUNC(vkBindBufferMemory2KHR, PFN_vkBindBufferMemory2KHR, "vkBindBufferMemory2KHR"); + VMA_FETCH_DEVICE_FUNC(vkBindImageMemory2KHR, PFN_vkBindImageMemory2KHR, "vkBindImageMemory2KHR"); + } +#endif // #if VMA_BIND_MEMORY2 + +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) + { + VMA_FETCH_INSTANCE_FUNC(vkGetPhysicalDeviceMemoryProperties2KHR, PFN_vkGetPhysicalDeviceMemoryProperties2KHR, "vkGetPhysicalDeviceMemoryProperties2KHR"); + } +#endif // #if VMA_MEMORY_BUDGET + +#undef VMA_FETCH_DEVICE_FUNC +#undef VMA_FETCH_INSTANCE_FUNC +} + +#endif // #if VMA_DYNAMIC_VULKAN_FUNCTIONS == 1 + +void VmaAllocator_T::ValidateVulkanFunctions() +{ + VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceProperties != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkAllocateMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkFreeMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkMapMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkUnmapMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkFlushMappedMemoryRanges != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkInvalidateMappedMemoryRanges != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkBindBufferMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkBindImageMemory != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetBufferMemoryRequirements != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetImageMemoryRequirements != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkCreateBuffer != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkDestroyBuffer != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkCreateImage != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkDestroyImage != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkCmdCopyBuffer != VMA_NULL); + +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0) || m_UseKhrDedicatedAllocation) + { + VMA_ASSERT(m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkGetImageMemoryRequirements2KHR != VMA_NULL); + } +#endif + +#if VMA_BIND_MEMORY2 || VMA_VULKAN_VERSION >= 1001000 + if(m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0) || m_UseKhrBindMemory2) + { + VMA_ASSERT(m_VulkanFunctions.vkBindBufferMemory2KHR != VMA_NULL); + VMA_ASSERT(m_VulkanFunctions.vkBindImageMemory2KHR != VMA_NULL); + } +#endif + +#if VMA_MEMORY_BUDGET || VMA_VULKAN_VERSION >= 1001000 + if(m_UseExtMemoryBudget || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VMA_ASSERT(m_VulkanFunctions.vkGetPhysicalDeviceMemoryProperties2KHR != VMA_NULL); + } +#endif +} + +VkDeviceSize VmaAllocator_T::CalcPreferredBlockSize(uint32_t memTypeIndex) +{ + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); + const VkDeviceSize heapSize = m_MemProps.memoryHeaps[heapIndex].size; + const bool isSmallHeap = heapSize <= VMA_SMALL_HEAP_MAX_SIZE; + return VmaAlignUp(isSmallHeap ? (heapSize / 8) : m_PreferredLargeHeapBlockSize, (VkDeviceSize)32); +} + +VkResult VmaAllocator_T::AllocateMemoryOfType( + VkDeviceSize size, + VkDeviceSize alignment, + bool dedicatedAllocation, + VkBuffer dedicatedBuffer, + VkBufferUsageFlags dedicatedBufferUsage, + VkImage dedicatedImage, + const VmaAllocationCreateInfo& createInfo, + uint32_t memTypeIndex, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations) +{ + VMA_ASSERT(pAllocations != VMA_NULL); + VMA_DEBUG_LOG(" AllocateMemory: MemoryTypeIndex=%u, AllocationCount=%zu, Size=%llu", memTypeIndex, allocationCount, size); + + VmaAllocationCreateInfo finalCreateInfo = createInfo; + + // If memory type is not HOST_VISIBLE, disable MAPPED. + if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && + (m_MemProps.memoryTypes[memTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + { + finalCreateInfo.flags &= ~VMA_ALLOCATION_CREATE_MAPPED_BIT; + } + // If memory is lazily allocated, it should be always dedicated. + if(finalCreateInfo.usage == VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED) + { + finalCreateInfo.flags |= VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; + } + + VmaBlockVector* const blockVector = m_pBlockVectors[memTypeIndex]; + VMA_ASSERT(blockVector); + + const VkDeviceSize preferredBlockSize = blockVector->GetPreferredBlockSize(); + bool preferDedicatedMemory = + VMA_DEBUG_ALWAYS_DEDICATED_MEMORY || + dedicatedAllocation || + // Heuristics: Allocate dedicated memory if requested size if greater than half of preferred block size. + size > preferredBlockSize / 2; + + if(preferDedicatedMemory && + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) == 0 && + finalCreateInfo.pool == VK_NULL_HANDLE) + { + finalCreateInfo.flags |= VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT; + } + + if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0) + { + if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) + { + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + else + { + return AllocateDedicatedMemory( + size, + suballocType, + memTypeIndex, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, + finalCreateInfo.pUserData, + dedicatedBuffer, + dedicatedBufferUsage, + dedicatedImage, + allocationCount, + pAllocations); + } + } + else + { + VkResult res = blockVector->Allocate( + m_CurrentFrameIndex.load(), + size, + alignment, + finalCreateInfo, + suballocType, + allocationCount, + pAllocations); + if(res == VK_SUCCESS) + { + return res; + } + + // 5. Try dedicated memory. + if((finalCreateInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) + { + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + else + { + res = AllocateDedicatedMemory( + size, + suballocType, + memTypeIndex, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_WITHIN_BUDGET_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0, + (finalCreateInfo.flags & VMA_ALLOCATION_CREATE_USER_DATA_COPY_STRING_BIT) != 0, + finalCreateInfo.pUserData, + dedicatedBuffer, + dedicatedBufferUsage, + dedicatedImage, + allocationCount, + pAllocations); + if(res == VK_SUCCESS) + { + // Succeeded: AllocateDedicatedMemory function already filld pMemory, nothing more to do here. + VMA_DEBUG_LOG(" Allocated as DedicatedMemory"); + return VK_SUCCESS; + } + else + { + // Everything failed: Return error code. + VMA_DEBUG_LOG(" vkAllocateMemory FAILED"); + return res; + } + } + } +} + +VkResult VmaAllocator_T::AllocateDedicatedMemory( + VkDeviceSize size, + VmaSuballocationType suballocType, + uint32_t memTypeIndex, + bool withinBudget, + bool map, + bool isUserDataString, + void* pUserData, + VkBuffer dedicatedBuffer, + VkBufferUsageFlags dedicatedBufferUsage, + VkImage dedicatedImage, + size_t allocationCount, + VmaAllocation* pAllocations) +{ + VMA_ASSERT(allocationCount > 0 && pAllocations); + + if(withinBudget) + { + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); + VmaBudget heapBudget = {}; + GetBudget(&heapBudget, heapIndex, 1); + if(heapBudget.usage + size * allocationCount > heapBudget.budget) + { + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + } + + VkMemoryAllocateInfo allocInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO }; + allocInfo.memoryTypeIndex = memTypeIndex; + allocInfo.allocationSize = size; + +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + VkMemoryDedicatedAllocateInfoKHR dedicatedAllocInfo = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR }; + if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + if(dedicatedBuffer != VK_NULL_HANDLE) + { + VMA_ASSERT(dedicatedImage == VK_NULL_HANDLE); + dedicatedAllocInfo.buffer = dedicatedBuffer; + VmaPnextChainPushFront(&allocInfo, &dedicatedAllocInfo); + } + else if(dedicatedImage != VK_NULL_HANDLE) + { + dedicatedAllocInfo.image = dedicatedImage; + VmaPnextChainPushFront(&allocInfo, &dedicatedAllocInfo); + } + } +#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + +#if VMA_BUFFER_DEVICE_ADDRESS + VkMemoryAllocateFlagsInfoKHR allocFlagsInfo = { VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_FLAGS_INFO_KHR }; + if(m_UseKhrBufferDeviceAddress) + { + bool canContainBufferWithDeviceAddress = true; + if(dedicatedBuffer != VK_NULL_HANDLE) + { + canContainBufferWithDeviceAddress = dedicatedBufferUsage == UINT32_MAX || // Usage flags unknown + (dedicatedBufferUsage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_EXT) != 0; + } + else if(dedicatedImage != VK_NULL_HANDLE) + { + canContainBufferWithDeviceAddress = false; + } + if(canContainBufferWithDeviceAddress) + { + allocFlagsInfo.flags = VK_MEMORY_ALLOCATE_DEVICE_ADDRESS_BIT_KHR; + VmaPnextChainPushFront(&allocInfo, &allocFlagsInfo); + } + } +#endif // #if VMA_BUFFER_DEVICE_ADDRESS + + size_t allocIndex; + VkResult res = VK_SUCCESS; + for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + res = AllocateDedicatedMemoryPage( + size, + suballocType, + memTypeIndex, + allocInfo, + map, + isUserDataString, + pUserData, + pAllocations + allocIndex); + if(res != VK_SUCCESS) + { + break; + } + } + + if(res == VK_SUCCESS) + { + // Register them in m_pDedicatedAllocations. + { + VmaMutexLockWrite lock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); + AllocationVectorType* pDedicatedAllocations = m_pDedicatedAllocations[memTypeIndex]; + VMA_ASSERT(pDedicatedAllocations); + for(allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + VmaVectorInsertSorted(*pDedicatedAllocations, pAllocations[allocIndex]); + } + } + + VMA_DEBUG_LOG(" Allocated DedicatedMemory Count=%zu, MemoryTypeIndex=#%u", allocationCount, memTypeIndex); + } + else + { + // Free all already created allocations. + while(allocIndex--) + { + VmaAllocation currAlloc = pAllocations[allocIndex]; + VkDeviceMemory hMemory = currAlloc->GetMemory(); + + /* + There is no need to call this, because Vulkan spec allows to skip vkUnmapMemory + before vkFreeMemory. + + if(currAlloc->GetMappedData() != VMA_NULL) + { + (*m_VulkanFunctions.vkUnmapMemory)(m_hDevice, hMemory); + } + */ + + FreeVulkanMemory(memTypeIndex, currAlloc->GetSize(), hMemory); + m_Budget.RemoveAllocation(MemoryTypeIndexToHeapIndex(memTypeIndex), currAlloc->GetSize()); + currAlloc->SetUserData(this, VMA_NULL); + m_AllocationObjectAllocator.Free(currAlloc); + } + + memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); + } + + return res; +} + +VkResult VmaAllocator_T::AllocateDedicatedMemoryPage( + VkDeviceSize size, + VmaSuballocationType suballocType, + uint32_t memTypeIndex, + const VkMemoryAllocateInfo& allocInfo, + bool map, + bool isUserDataString, + void* pUserData, + VmaAllocation* pAllocation) +{ + VkDeviceMemory hMemory = VK_NULL_HANDLE; + VkResult res = AllocateVulkanMemory(&allocInfo, &hMemory); + if(res < 0) + { + VMA_DEBUG_LOG(" vkAllocateMemory FAILED"); + return res; + } + + void* pMappedData = VMA_NULL; + if(map) + { + res = (*m_VulkanFunctions.vkMapMemory)( + m_hDevice, + hMemory, + 0, + VK_WHOLE_SIZE, + 0, + &pMappedData); + if(res < 0) + { + VMA_DEBUG_LOG(" vkMapMemory FAILED"); + FreeVulkanMemory(memTypeIndex, size, hMemory); + return res; + } + } + + *pAllocation = m_AllocationObjectAllocator.Allocate(m_CurrentFrameIndex.load(), isUserDataString); + (*pAllocation)->InitDedicatedAllocation(memTypeIndex, hMemory, suballocType, pMappedData, size); + (*pAllocation)->SetUserData(this, pUserData); + m_Budget.AddAllocation(MemoryTypeIndexToHeapIndex(memTypeIndex), size); + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + { + FillAllocation(*pAllocation, VMA_ALLOCATION_FILL_PATTERN_CREATED); + } + + return VK_SUCCESS; +} + +void VmaAllocator_T::GetBufferMemoryRequirements( + VkBuffer hBuffer, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const +{ +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VkBufferMemoryRequirementsInfo2KHR memReqInfo = { VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR }; + memReqInfo.buffer = hBuffer; + + VkMemoryDedicatedRequirementsKHR memDedicatedReq = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR }; + + VkMemoryRequirements2KHR memReq2 = { VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR }; + VmaPnextChainPushFront(&memReq2, &memDedicatedReq); + + (*m_VulkanFunctions.vkGetBufferMemoryRequirements2KHR)(m_hDevice, &memReqInfo, &memReq2); + + memReq = memReq2.memoryRequirements; + requiresDedicatedAllocation = (memDedicatedReq.requiresDedicatedAllocation != VK_FALSE); + prefersDedicatedAllocation = (memDedicatedReq.prefersDedicatedAllocation != VK_FALSE); + } + else +#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + { + (*m_VulkanFunctions.vkGetBufferMemoryRequirements)(m_hDevice, hBuffer, &memReq); + requiresDedicatedAllocation = false; + prefersDedicatedAllocation = false; + } +} + +void VmaAllocator_T::GetImageMemoryRequirements( + VkImage hImage, + VkMemoryRequirements& memReq, + bool& requiresDedicatedAllocation, + bool& prefersDedicatedAllocation) const +{ +#if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + if(m_UseKhrDedicatedAllocation || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) + { + VkImageMemoryRequirementsInfo2KHR memReqInfo = { VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2_KHR }; + memReqInfo.image = hImage; + + VkMemoryDedicatedRequirementsKHR memDedicatedReq = { VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR }; + + VkMemoryRequirements2KHR memReq2 = { VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR }; + VmaPnextChainPushFront(&memReq2, &memDedicatedReq); + + (*m_VulkanFunctions.vkGetImageMemoryRequirements2KHR)(m_hDevice, &memReqInfo, &memReq2); + + memReq = memReq2.memoryRequirements; + requiresDedicatedAllocation = (memDedicatedReq.requiresDedicatedAllocation != VK_FALSE); + prefersDedicatedAllocation = (memDedicatedReq.prefersDedicatedAllocation != VK_FALSE); + } + else +#endif // #if VMA_DEDICATED_ALLOCATION || VMA_VULKAN_VERSION >= 1001000 + { + (*m_VulkanFunctions.vkGetImageMemoryRequirements)(m_hDevice, hImage, &memReq); + requiresDedicatedAllocation = false; + prefersDedicatedAllocation = false; + } +} + +VkResult VmaAllocator_T::AllocateMemory( + const VkMemoryRequirements& vkMemReq, + bool requiresDedicatedAllocation, + bool prefersDedicatedAllocation, + VkBuffer dedicatedBuffer, + VkBufferUsageFlags dedicatedBufferUsage, + VkImage dedicatedImage, + const VmaAllocationCreateInfo& createInfo, + VmaSuballocationType suballocType, + size_t allocationCount, + VmaAllocation* pAllocations) +{ + memset(pAllocations, 0, sizeof(VmaAllocation) * allocationCount); + + VMA_ASSERT(VmaIsPow2(vkMemReq.alignment)); + + if(vkMemReq.size == 0) + { + return VK_ERROR_VALIDATION_FAILED_EXT; + } + if((createInfo.flags & VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT) != 0 && + (createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) + { + VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT together with VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT makes no sense."); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + if((createInfo.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && + (createInfo.flags & VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT) != 0) + { + VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_MAPPED_BIT together with VMA_ALLOCATION_CREATE_CAN_BECOME_LOST_BIT is invalid."); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + if(requiresDedicatedAllocation) + { + if((createInfo.flags & VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT) != 0) + { + VMA_ASSERT(0 && "VMA_ALLOCATION_CREATE_NEVER_ALLOCATE_BIT specified while dedicated allocation is required."); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + if(createInfo.pool != VK_NULL_HANDLE) + { + VMA_ASSERT(0 && "Pool specified while dedicated allocation is required."); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + } + if((createInfo.pool != VK_NULL_HANDLE) && + ((createInfo.flags & (VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT)) != 0)) + { + VMA_ASSERT(0 && "Specifying VMA_ALLOCATION_CREATE_DEDICATED_MEMORY_BIT when pool != null is invalid."); + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + + if(createInfo.pool != VK_NULL_HANDLE) + { + const VkDeviceSize alignmentForPool = VMA_MAX( + vkMemReq.alignment, + GetMemoryTypeMinAlignment(createInfo.pool->m_BlockVector.GetMemoryTypeIndex())); + + VmaAllocationCreateInfo createInfoForPool = createInfo; + // If memory type is not HOST_VISIBLE, disable MAPPED. + if((createInfoForPool.flags & VMA_ALLOCATION_CREATE_MAPPED_BIT) != 0 && + (m_MemProps.memoryTypes[createInfo.pool->m_BlockVector.GetMemoryTypeIndex()].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + { + createInfoForPool.flags &= ~VMA_ALLOCATION_CREATE_MAPPED_BIT; + } + + return createInfo.pool->m_BlockVector.Allocate( + m_CurrentFrameIndex.load(), + vkMemReq.size, + alignmentForPool, + createInfoForPool, + suballocType, + allocationCount, + pAllocations); + } + else + { + // Bit mask of memory Vulkan types acceptable for this allocation. + uint32_t memoryTypeBits = vkMemReq.memoryTypeBits; + uint32_t memTypeIndex = UINT32_MAX; + VkResult res = vmaFindMemoryTypeIndex(this, memoryTypeBits, &createInfo, &memTypeIndex); + if(res == VK_SUCCESS) + { + VkDeviceSize alignmentForMemType = VMA_MAX( + vkMemReq.alignment, + GetMemoryTypeMinAlignment(memTypeIndex)); + + res = AllocateMemoryOfType( + vkMemReq.size, + alignmentForMemType, + requiresDedicatedAllocation || prefersDedicatedAllocation, + dedicatedBuffer, + dedicatedBufferUsage, + dedicatedImage, + createInfo, + memTypeIndex, + suballocType, + allocationCount, + pAllocations); + // Succeeded on first try. + if(res == VK_SUCCESS) + { + return res; + } + // Allocation from this memory type failed. Try other compatible memory types. + else + { + for(;;) + { + // Remove old memTypeIndex from list of possibilities. + memoryTypeBits &= ~(1u << memTypeIndex); + // Find alternative memTypeIndex. + res = vmaFindMemoryTypeIndex(this, memoryTypeBits, &createInfo, &memTypeIndex); + if(res == VK_SUCCESS) + { + alignmentForMemType = VMA_MAX( + vkMemReq.alignment, + GetMemoryTypeMinAlignment(memTypeIndex)); + + res = AllocateMemoryOfType( + vkMemReq.size, + alignmentForMemType, + requiresDedicatedAllocation || prefersDedicatedAllocation, + dedicatedBuffer, + dedicatedBufferUsage, + dedicatedImage, + createInfo, + memTypeIndex, + suballocType, + allocationCount, + pAllocations); + // Allocation from this alternative memory type succeeded. + if(res == VK_SUCCESS) + { + return res; + } + // else: Allocation from this memory type failed. Try next one - next loop iteration. + } + // No other matching memory type index could be found. + else + { + // Not returning res, which is VK_ERROR_FEATURE_NOT_PRESENT, because we already failed to allocate once. + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + } + } + } + // Can't find any single memory type maching requirements. res is VK_ERROR_FEATURE_NOT_PRESENT. + else + return res; + } +} + +void VmaAllocator_T::FreeMemory( + size_t allocationCount, + const VmaAllocation* pAllocations) +{ + VMA_ASSERT(pAllocations); + + for(size_t allocIndex = allocationCount; allocIndex--; ) + { + VmaAllocation allocation = pAllocations[allocIndex]; + + if(allocation != VK_NULL_HANDLE) + { + if(TouchAllocation(allocation)) + { + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS) + { + FillAllocation(allocation, VMA_ALLOCATION_FILL_PATTERN_DESTROYED); + } + + switch(allocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaBlockVector* pBlockVector = VMA_NULL; + VmaPool hPool = allocation->GetBlock()->GetParentPool(); + if(hPool != VK_NULL_HANDLE) + { + pBlockVector = &hPool->m_BlockVector; + } + else + { + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + pBlockVector = m_pBlockVectors[memTypeIndex]; + } + pBlockVector->Free(allocation); + } + break; + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + FreeDedicatedMemory(allocation); + break; + default: + VMA_ASSERT(0); + } + } + + // Do this regardless of whether the allocation is lost. Lost allocations still account to Budget.AllocationBytes. + m_Budget.RemoveAllocation(MemoryTypeIndexToHeapIndex(allocation->GetMemoryTypeIndex()), allocation->GetSize()); + allocation->SetUserData(this, VMA_NULL); + m_AllocationObjectAllocator.Free(allocation); + } + } +} + +VkResult VmaAllocator_T::ResizeAllocation( + const VmaAllocation alloc, + VkDeviceSize newSize) +{ + // This function is deprecated and so it does nothing. It's left for backward compatibility. + if(newSize == 0 || alloc->GetLastUseFrameIndex() == VMA_FRAME_INDEX_LOST) + { + return VK_ERROR_VALIDATION_FAILED_EXT; + } + if(newSize == alloc->GetSize()) + { + return VK_SUCCESS; + } + return VK_ERROR_OUT_OF_POOL_MEMORY; +} + +void VmaAllocator_T::CalculateStats(VmaStats* pStats) +{ + // Initialize. + InitStatInfo(pStats->total); + for(size_t i = 0; i < VK_MAX_MEMORY_TYPES; ++i) + InitStatInfo(pStats->memoryType[i]); + for(size_t i = 0; i < VK_MAX_MEMORY_HEAPS; ++i) + InitStatInfo(pStats->memoryHeap[i]); + + // Process default pools. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + VmaBlockVector* const pBlockVector = m_pBlockVectors[memTypeIndex]; + VMA_ASSERT(pBlockVector); + pBlockVector->AddStats(pStats); + } + + // Process custom pools. + { + VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); + for(size_t poolIndex = 0, poolCount = m_Pools.size(); poolIndex < poolCount; ++poolIndex) + { + m_Pools[poolIndex]->m_BlockVector.AddStats(pStats); + } + } + + // Process dedicated allocations. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + const uint32_t memHeapIndex = MemoryTypeIndexToHeapIndex(memTypeIndex); + VmaMutexLockRead dedicatedAllocationsLock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); + AllocationVectorType* const pDedicatedAllocVector = m_pDedicatedAllocations[memTypeIndex]; + VMA_ASSERT(pDedicatedAllocVector); + for(size_t allocIndex = 0, allocCount = pDedicatedAllocVector->size(); allocIndex < allocCount; ++allocIndex) + { + VmaStatInfo allocationStatInfo; + (*pDedicatedAllocVector)[allocIndex]->DedicatedAllocCalcStatsInfo(allocationStatInfo); + VmaAddStatInfo(pStats->total, allocationStatInfo); + VmaAddStatInfo(pStats->memoryType[memTypeIndex], allocationStatInfo); + VmaAddStatInfo(pStats->memoryHeap[memHeapIndex], allocationStatInfo); + } + } + + // Postprocess. + VmaPostprocessCalcStatInfo(pStats->total); + for(size_t i = 0; i < GetMemoryTypeCount(); ++i) + VmaPostprocessCalcStatInfo(pStats->memoryType[i]); + for(size_t i = 0; i < GetMemoryHeapCount(); ++i) + VmaPostprocessCalcStatInfo(pStats->memoryHeap[i]); +} + +void VmaAllocator_T::GetBudget(VmaBudget* outBudget, uint32_t firstHeap, uint32_t heapCount) +{ +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) + { + if(m_Budget.m_OperationsSinceBudgetFetch < 30) + { + VmaMutexLockRead lockRead(m_Budget.m_BudgetMutex, m_UseMutex); + for(uint32_t i = 0; i < heapCount; ++i, ++outBudget) + { + const uint32_t heapIndex = firstHeap + i; + + outBudget->blockBytes = m_Budget.m_BlockBytes[heapIndex]; + outBudget->allocationBytes = m_Budget.m_AllocationBytes[heapIndex]; + + if(m_Budget.m_VulkanUsage[heapIndex] + outBudget->blockBytes > m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]) + { + outBudget->usage = m_Budget.m_VulkanUsage[heapIndex] + + outBudget->blockBytes - m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]; + } + else + { + outBudget->usage = 0; + } + + // Have to take MIN with heap size because explicit HeapSizeLimit is included in it. + outBudget->budget = VMA_MIN( + m_Budget.m_VulkanBudget[heapIndex], m_MemProps.memoryHeaps[heapIndex].size); + } + } + else + { + UpdateVulkanBudget(); // Outside of mutex lock + GetBudget(outBudget, firstHeap, heapCount); // Recursion + } + } + else +#endif + { + for(uint32_t i = 0; i < heapCount; ++i, ++outBudget) + { + const uint32_t heapIndex = firstHeap + i; + + outBudget->blockBytes = m_Budget.m_BlockBytes[heapIndex]; + outBudget->allocationBytes = m_Budget.m_AllocationBytes[heapIndex]; + + outBudget->usage = outBudget->blockBytes; + outBudget->budget = m_MemProps.memoryHeaps[heapIndex].size * 8 / 10; // 80% heuristics. + } + } +} + +static const uint32_t VMA_VENDOR_ID_AMD = 4098; + +VkResult VmaAllocator_T::DefragmentationBegin( + const VmaDefragmentationInfo2& info, + VmaDefragmentationStats* pStats, + VmaDefragmentationContext* pContext) +{ + if(info.pAllocationsChanged != VMA_NULL) + { + memset(info.pAllocationsChanged, 0, info.allocationCount * sizeof(VkBool32)); + } + + *pContext = vma_new(this, VmaDefragmentationContext_T)( + this, m_CurrentFrameIndex.load(), info.flags, pStats); + + (*pContext)->AddPools(info.poolCount, info.pPools); + (*pContext)->AddAllocations( + info.allocationCount, info.pAllocations, info.pAllocationsChanged); + + VkResult res = (*pContext)->Defragment( + info.maxCpuBytesToMove, info.maxCpuAllocationsToMove, + info.maxGpuBytesToMove, info.maxGpuAllocationsToMove, + info.commandBuffer, pStats, info.flags); + + if(res != VK_NOT_READY) + { + vma_delete(this, *pContext); + *pContext = VMA_NULL; + } + + return res; +} + +VkResult VmaAllocator_T::DefragmentationEnd( + VmaDefragmentationContext context) +{ + vma_delete(this, context); + return VK_SUCCESS; +} + +VkResult VmaAllocator_T::DefragmentationPassBegin( + VmaDefragmentationPassInfo* pInfo, + VmaDefragmentationContext context) +{ + return context->DefragmentPassBegin(pInfo); +} +VkResult VmaAllocator_T::DefragmentationPassEnd( + VmaDefragmentationContext context) +{ + return context->DefragmentPassEnd(); + +} + +void VmaAllocator_T::GetAllocationInfo(VmaAllocation hAllocation, VmaAllocationInfo* pAllocationInfo) +{ + if(hAllocation->CanBecomeLost()) + { + /* + Warning: This is a carefully designed algorithm. + Do not modify unless you really know what you're doing :) + */ + const uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); + uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); + for(;;) + { + if(localLastUseFrameIndex == VMA_FRAME_INDEX_LOST) + { + pAllocationInfo->memoryType = UINT32_MAX; + pAllocationInfo->deviceMemory = VK_NULL_HANDLE; + pAllocationInfo->offset = 0; + pAllocationInfo->size = hAllocation->GetSize(); + pAllocationInfo->pMappedData = VMA_NULL; + pAllocationInfo->pUserData = hAllocation->GetUserData(); + return; + } + else if(localLastUseFrameIndex == localCurrFrameIndex) + { + pAllocationInfo->memoryType = hAllocation->GetMemoryTypeIndex(); + pAllocationInfo->deviceMemory = hAllocation->GetMemory(); + pAllocationInfo->offset = hAllocation->GetOffset(); + pAllocationInfo->size = hAllocation->GetSize(); + pAllocationInfo->pMappedData = VMA_NULL; + pAllocationInfo->pUserData = hAllocation->GetUserData(); + return; + } + else // Last use time earlier than current time. + { + if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) + { + localLastUseFrameIndex = localCurrFrameIndex; + } + } + } + } + else + { +#if VMA_STATS_STRING_ENABLED + uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); + uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); + for(;;) + { + VMA_ASSERT(localLastUseFrameIndex != VMA_FRAME_INDEX_LOST); + if(localLastUseFrameIndex == localCurrFrameIndex) + { + break; + } + else // Last use time earlier than current time. + { + if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) + { + localLastUseFrameIndex = localCurrFrameIndex; + } + } + } +#endif + + pAllocationInfo->memoryType = hAllocation->GetMemoryTypeIndex(); + pAllocationInfo->deviceMemory = hAllocation->GetMemory(); + pAllocationInfo->offset = hAllocation->GetOffset(); + pAllocationInfo->size = hAllocation->GetSize(); + pAllocationInfo->pMappedData = hAllocation->GetMappedData(); + pAllocationInfo->pUserData = hAllocation->GetUserData(); + } +} + +bool VmaAllocator_T::TouchAllocation(VmaAllocation hAllocation) +{ + // This is a stripped-down version of VmaAllocator_T::GetAllocationInfo. + if(hAllocation->CanBecomeLost()) + { + uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); + uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); + for(;;) + { + if(localLastUseFrameIndex == VMA_FRAME_INDEX_LOST) + { + return false; + } + else if(localLastUseFrameIndex == localCurrFrameIndex) + { + return true; + } + else // Last use time earlier than current time. + { + if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) + { + localLastUseFrameIndex = localCurrFrameIndex; + } + } + } + } + else + { +#if VMA_STATS_STRING_ENABLED + uint32_t localCurrFrameIndex = m_CurrentFrameIndex.load(); + uint32_t localLastUseFrameIndex = hAllocation->GetLastUseFrameIndex(); + for(;;) + { + VMA_ASSERT(localLastUseFrameIndex != VMA_FRAME_INDEX_LOST); + if(localLastUseFrameIndex == localCurrFrameIndex) + { + break; + } + else // Last use time earlier than current time. + { + if(hAllocation->CompareExchangeLastUseFrameIndex(localLastUseFrameIndex, localCurrFrameIndex)) + { + localLastUseFrameIndex = localCurrFrameIndex; + } + } + } +#endif + + return true; + } +} + +VkResult VmaAllocator_T::CreatePool(const VmaPoolCreateInfo* pCreateInfo, VmaPool* pPool) +{ + VMA_DEBUG_LOG(" CreatePool: MemoryTypeIndex=%u, flags=%u", pCreateInfo->memoryTypeIndex, pCreateInfo->flags); + + VmaPoolCreateInfo newCreateInfo = *pCreateInfo; + + if(newCreateInfo.maxBlockCount == 0) + { + newCreateInfo.maxBlockCount = SIZE_MAX; + } + if(newCreateInfo.minBlockCount > newCreateInfo.maxBlockCount) + { + return VK_ERROR_INITIALIZATION_FAILED; + } + // Memory type index out of range or forbidden. + if(pCreateInfo->memoryTypeIndex >= GetMemoryTypeCount() || + ((1u << pCreateInfo->memoryTypeIndex) & m_GlobalMemoryTypeBits) == 0) + { + return VK_ERROR_FEATURE_NOT_PRESENT; + } + + const VkDeviceSize preferredBlockSize = CalcPreferredBlockSize(newCreateInfo.memoryTypeIndex); + + *pPool = vma_new(this, VmaPool_T)(this, newCreateInfo, preferredBlockSize); + + VkResult res = (*pPool)->m_BlockVector.CreateMinBlocks(); + if(res != VK_SUCCESS) + { + vma_delete(this, *pPool); + *pPool = VMA_NULL; + return res; + } + + // Add to m_Pools. + { + VmaMutexLockWrite lock(m_PoolsMutex, m_UseMutex); + (*pPool)->SetId(m_NextPoolId++); + VmaVectorInsertSorted(m_Pools, *pPool); + } + + return VK_SUCCESS; +} + +void VmaAllocator_T::DestroyPool(VmaPool pool) +{ + // Remove from m_Pools. + { + VmaMutexLockWrite lock(m_PoolsMutex, m_UseMutex); + bool success = VmaVectorRemoveSorted(m_Pools, pool); + VMA_ASSERT(success && "Pool not found in Allocator."); + } + + vma_delete(this, pool); +} + +void VmaAllocator_T::GetPoolStats(VmaPool pool, VmaPoolStats* pPoolStats) +{ + pool->m_BlockVector.GetPoolStats(pPoolStats); +} + +void VmaAllocator_T::SetCurrentFrameIndex(uint32_t frameIndex) +{ + m_CurrentFrameIndex.store(frameIndex); + +#if VMA_MEMORY_BUDGET + if(m_UseExtMemoryBudget) + { + UpdateVulkanBudget(); + } +#endif // #if VMA_MEMORY_BUDGET +} + +void VmaAllocator_T::MakePoolAllocationsLost( + VmaPool hPool, + size_t* pLostAllocationCount) +{ + hPool->m_BlockVector.MakePoolAllocationsLost( + m_CurrentFrameIndex.load(), + pLostAllocationCount); +} + +VkResult VmaAllocator_T::CheckPoolCorruption(VmaPool hPool) +{ + return hPool->m_BlockVector.CheckCorruption(); +} + +VkResult VmaAllocator_T::CheckCorruption(uint32_t memoryTypeBits) +{ + VkResult finalRes = VK_ERROR_FEATURE_NOT_PRESENT; + + // Process default pools. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + if(((1u << memTypeIndex) & memoryTypeBits) != 0) + { + VmaBlockVector* const pBlockVector = m_pBlockVectors[memTypeIndex]; + VMA_ASSERT(pBlockVector); + VkResult localRes = pBlockVector->CheckCorruption(); + switch(localRes) + { + case VK_ERROR_FEATURE_NOT_PRESENT: + break; + case VK_SUCCESS: + finalRes = VK_SUCCESS; + break; + default: + return localRes; + } + } + } + + // Process custom pools. + { + VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); + for(size_t poolIndex = 0, poolCount = m_Pools.size(); poolIndex < poolCount; ++poolIndex) + { + if(((1u << m_Pools[poolIndex]->m_BlockVector.GetMemoryTypeIndex()) & memoryTypeBits) != 0) + { + VkResult localRes = m_Pools[poolIndex]->m_BlockVector.CheckCorruption(); + switch(localRes) + { + case VK_ERROR_FEATURE_NOT_PRESENT: + break; + case VK_SUCCESS: + finalRes = VK_SUCCESS; + break; + default: + return localRes; + } + } + } + } + + return finalRes; +} + +void VmaAllocator_T::CreateLostAllocation(VmaAllocation* pAllocation) +{ + *pAllocation = m_AllocationObjectAllocator.Allocate(VMA_FRAME_INDEX_LOST, false); + (*pAllocation)->InitLost(); +} + +VkResult VmaAllocator_T::AllocateVulkanMemory(const VkMemoryAllocateInfo* pAllocateInfo, VkDeviceMemory* pMemory) +{ + const uint32_t heapIndex = MemoryTypeIndexToHeapIndex(pAllocateInfo->memoryTypeIndex); + + // HeapSizeLimit is in effect for this heap. + if((m_HeapSizeLimitMask & (1u << heapIndex)) != 0) + { + const VkDeviceSize heapSize = m_MemProps.memoryHeaps[heapIndex].size; + VkDeviceSize blockBytes = m_Budget.m_BlockBytes[heapIndex]; + for(;;) + { + const VkDeviceSize blockBytesAfterAllocation = blockBytes + pAllocateInfo->allocationSize; + if(blockBytesAfterAllocation > heapSize) + { + return VK_ERROR_OUT_OF_DEVICE_MEMORY; + } + if(m_Budget.m_BlockBytes[heapIndex].compare_exchange_strong(blockBytes, blockBytesAfterAllocation)) + { + break; + } + } + } + else + { + m_Budget.m_BlockBytes[heapIndex] += pAllocateInfo->allocationSize; + } + + // VULKAN CALL vkAllocateMemory. + VkResult res = (*m_VulkanFunctions.vkAllocateMemory)(m_hDevice, pAllocateInfo, GetAllocationCallbacks(), pMemory); + + if(res == VK_SUCCESS) + { +#if VMA_MEMORY_BUDGET + ++m_Budget.m_OperationsSinceBudgetFetch; +#endif + + // Informative callback. + if(m_DeviceMemoryCallbacks.pfnAllocate != VMA_NULL) + { + (*m_DeviceMemoryCallbacks.pfnAllocate)(this, pAllocateInfo->memoryTypeIndex, *pMemory, pAllocateInfo->allocationSize, m_DeviceMemoryCallbacks.pUserData); + } + } + else + { + m_Budget.m_BlockBytes[heapIndex] -= pAllocateInfo->allocationSize; + } + + return res; +} + +void VmaAllocator_T::FreeVulkanMemory(uint32_t memoryType, VkDeviceSize size, VkDeviceMemory hMemory) +{ + // Informative callback. + if(m_DeviceMemoryCallbacks.pfnFree != VMA_NULL) + { + (*m_DeviceMemoryCallbacks.pfnFree)(this, memoryType, hMemory, size, m_DeviceMemoryCallbacks.pUserData); + } + + // VULKAN CALL vkFreeMemory. + (*m_VulkanFunctions.vkFreeMemory)(m_hDevice, hMemory, GetAllocationCallbacks()); + + m_Budget.m_BlockBytes[MemoryTypeIndexToHeapIndex(memoryType)] -= size; +} + +VkResult VmaAllocator_T::BindVulkanBuffer( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkBuffer buffer, + const void* pNext) +{ + if(pNext != VMA_NULL) + { +#if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 + if((m_UseKhrBindMemory2 || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) && + m_VulkanFunctions.vkBindBufferMemory2KHR != VMA_NULL) + { + VkBindBufferMemoryInfoKHR bindBufferMemoryInfo = { VK_STRUCTURE_TYPE_BIND_BUFFER_MEMORY_INFO_KHR }; + bindBufferMemoryInfo.pNext = pNext; + bindBufferMemoryInfo.buffer = buffer; + bindBufferMemoryInfo.memory = memory; + bindBufferMemoryInfo.memoryOffset = memoryOffset; + return (*m_VulkanFunctions.vkBindBufferMemory2KHR)(m_hDevice, 1, &bindBufferMemoryInfo); + } + else +#endif // #if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 + { + return VK_ERROR_EXTENSION_NOT_PRESENT; + } + } + else + { + return (*m_VulkanFunctions.vkBindBufferMemory)(m_hDevice, buffer, memory, memoryOffset); + } +} + +VkResult VmaAllocator_T::BindVulkanImage( + VkDeviceMemory memory, + VkDeviceSize memoryOffset, + VkImage image, + const void* pNext) +{ + if(pNext != VMA_NULL) + { +#if VMA_VULKAN_VERSION >= 1001000 || VMA_BIND_MEMORY2 + if((m_UseKhrBindMemory2 || m_VulkanApiVersion >= VK_MAKE_VERSION(1, 1, 0)) && + m_VulkanFunctions.vkBindImageMemory2KHR != VMA_NULL) + { + VkBindImageMemoryInfoKHR bindBufferMemoryInfo = { VK_STRUCTURE_TYPE_BIND_IMAGE_MEMORY_INFO_KHR }; + bindBufferMemoryInfo.pNext = pNext; + bindBufferMemoryInfo.image = image; + bindBufferMemoryInfo.memory = memory; + bindBufferMemoryInfo.memoryOffset = memoryOffset; + return (*m_VulkanFunctions.vkBindImageMemory2KHR)(m_hDevice, 1, &bindBufferMemoryInfo); + } + else +#endif // #if VMA_BIND_MEMORY2 + { + return VK_ERROR_EXTENSION_NOT_PRESENT; + } + } + else + { + return (*m_VulkanFunctions.vkBindImageMemory)(m_hDevice, image, memory, memoryOffset); + } +} + +VkResult VmaAllocator_T::Map(VmaAllocation hAllocation, void** ppData) +{ + if(hAllocation->CanBecomeLost()) + { + return VK_ERROR_MEMORY_MAP_FAILED; + } + + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); + char *pBytes = VMA_NULL; + VkResult res = pBlock->Map(this, 1, (void**)&pBytes); + if(res == VK_SUCCESS) + { + *ppData = pBytes + (ptrdiff_t)hAllocation->GetOffset(); + hAllocation->BlockAllocMap(); + } + return res; + } + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + return hAllocation->DedicatedAllocMap(this, ppData); + default: + VMA_ASSERT(0); + return VK_ERROR_MEMORY_MAP_FAILED; + } +} + +void VmaAllocator_T::Unmap(VmaAllocation hAllocation) +{ + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); + hAllocation->BlockAllocUnmap(); + pBlock->Unmap(this, 1); + } + break; + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + hAllocation->DedicatedAllocUnmap(this); + break; + default: + VMA_ASSERT(0); + } +} + +VkResult VmaAllocator_T::BindBufferMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkBuffer hBuffer, + const void* pNext) +{ + VkResult res = VK_SUCCESS; + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + res = BindVulkanBuffer(hAllocation->GetMemory(), allocationLocalOffset, hBuffer, pNext); + break; + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* const pBlock = hAllocation->GetBlock(); + VMA_ASSERT(pBlock && "Binding buffer to allocation that doesn't belong to any block. Is the allocation lost?"); + res = pBlock->BindBufferMemory(this, hAllocation, allocationLocalOffset, hBuffer, pNext); + break; + } + default: + VMA_ASSERT(0); + } + return res; +} + +VkResult VmaAllocator_T::BindImageMemory( + VmaAllocation hAllocation, + VkDeviceSize allocationLocalOffset, + VkImage hImage, + const void* pNext) +{ + VkResult res = VK_SUCCESS; + switch(hAllocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + res = BindVulkanImage(hAllocation->GetMemory(), allocationLocalOffset, hImage, pNext); + break; + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + VmaDeviceMemoryBlock* pBlock = hAllocation->GetBlock(); + VMA_ASSERT(pBlock && "Binding image to allocation that doesn't belong to any block. Is the allocation lost?"); + res = pBlock->BindImageMemory(this, hAllocation, allocationLocalOffset, hImage, pNext); + break; + } + default: + VMA_ASSERT(0); + } + return res; +} + +VkResult VmaAllocator_T::FlushOrInvalidateAllocation( + VmaAllocation hAllocation, + VkDeviceSize offset, VkDeviceSize size, + VMA_CACHE_OPERATION op) +{ + VkResult res = VK_SUCCESS; + + VkMappedMemoryRange memRange = {}; + if(GetFlushOrInvalidateRange(hAllocation, offset, size, memRange)) + { + switch(op) + { + case VMA_CACHE_FLUSH: + res = (*GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hDevice, 1, &memRange); + break; + case VMA_CACHE_INVALIDATE: + res = (*GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hDevice, 1, &memRange); + break; + default: + VMA_ASSERT(0); + } + } + // else: Just ignore this call. + return res; +} + +VkResult VmaAllocator_T::FlushOrInvalidateAllocations( + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, const VkDeviceSize* sizes, + VMA_CACHE_OPERATION op) +{ + typedef VmaStlAllocator RangeAllocator; + typedef VmaSmallVector RangeVector; + RangeVector ranges = RangeVector(RangeAllocator(GetAllocationCallbacks())); + + for(uint32_t allocIndex = 0; allocIndex < allocationCount; ++allocIndex) + { + const VmaAllocation alloc = allocations[allocIndex]; + const VkDeviceSize offset = offsets != VMA_NULL ? offsets[allocIndex] : 0; + const VkDeviceSize size = sizes != VMA_NULL ? sizes[allocIndex] : VK_WHOLE_SIZE; + VkMappedMemoryRange newRange; + if(GetFlushOrInvalidateRange(alloc, offset, size, newRange)) + { + ranges.push_back(newRange); + } + } + + VkResult res = VK_SUCCESS; + if(!ranges.empty()) + { + switch(op) + { + case VMA_CACHE_FLUSH: + res = (*GetVulkanFunctions().vkFlushMappedMemoryRanges)(m_hDevice, (uint32_t)ranges.size(), ranges.data()); + break; + case VMA_CACHE_INVALIDATE: + res = (*GetVulkanFunctions().vkInvalidateMappedMemoryRanges)(m_hDevice, (uint32_t)ranges.size(), ranges.data()); + break; + default: + VMA_ASSERT(0); + } + } + // else: Just ignore this call. + return res; +} + +void VmaAllocator_T::FreeDedicatedMemory(const VmaAllocation allocation) +{ + VMA_ASSERT(allocation && allocation->GetType() == VmaAllocation_T::ALLOCATION_TYPE_DEDICATED); + + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + { + VmaMutexLockWrite lock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); + AllocationVectorType* const pDedicatedAllocations = m_pDedicatedAllocations[memTypeIndex]; + VMA_ASSERT(pDedicatedAllocations); + bool success = VmaVectorRemoveSorted(*pDedicatedAllocations, allocation); + VMA_ASSERT(success); + } + + VkDeviceMemory hMemory = allocation->GetMemory(); + + /* + There is no need to call this, because Vulkan spec allows to skip vkUnmapMemory + before vkFreeMemory. + + if(allocation->GetMappedData() != VMA_NULL) + { + (*m_VulkanFunctions.vkUnmapMemory)(m_hDevice, hMemory); + } + */ + + FreeVulkanMemory(memTypeIndex, allocation->GetSize(), hMemory); + + VMA_DEBUG_LOG(" Freed DedicatedMemory MemoryTypeIndex=%u", memTypeIndex); +} + +uint32_t VmaAllocator_T::CalculateGpuDefragmentationMemoryTypeBits() const +{ + VkBufferCreateInfo dummyBufCreateInfo; + VmaFillGpuDefragmentationBufferCreateInfo(dummyBufCreateInfo); + + uint32_t memoryTypeBits = 0; + + // Create buffer. + VkBuffer buf = VK_NULL_HANDLE; + VkResult res = (*GetVulkanFunctions().vkCreateBuffer)( + m_hDevice, &dummyBufCreateInfo, GetAllocationCallbacks(), &buf); + if(res == VK_SUCCESS) + { + // Query for supported memory types. + VkMemoryRequirements memReq; + (*GetVulkanFunctions().vkGetBufferMemoryRequirements)(m_hDevice, buf, &memReq); + memoryTypeBits = memReq.memoryTypeBits; + + // Destroy buffer. + (*GetVulkanFunctions().vkDestroyBuffer)(m_hDevice, buf, GetAllocationCallbacks()); + } + + return memoryTypeBits; +} + +uint32_t VmaAllocator_T::CalculateGlobalMemoryTypeBits() const +{ + // Make sure memory information is already fetched. + VMA_ASSERT(GetMemoryTypeCount() > 0); + + uint32_t memoryTypeBits = UINT32_MAX; + + if(!m_UseAmdDeviceCoherentMemory) + { + // Exclude memory types that have VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD. + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + if((m_MemProps.memoryTypes[memTypeIndex].propertyFlags & VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY) != 0) + { + memoryTypeBits &= ~(1u << memTypeIndex); + } + } + } + + return memoryTypeBits; +} + +bool VmaAllocator_T::GetFlushOrInvalidateRange( + VmaAllocation allocation, + VkDeviceSize offset, VkDeviceSize size, + VkMappedMemoryRange& outRange) const +{ + const uint32_t memTypeIndex = allocation->GetMemoryTypeIndex(); + if(size > 0 && IsMemoryTypeNonCoherent(memTypeIndex)) + { + const VkDeviceSize nonCoherentAtomSize = m_PhysicalDeviceProperties.limits.nonCoherentAtomSize; + const VkDeviceSize allocationSize = allocation->GetSize(); + VMA_ASSERT(offset <= allocationSize); + + outRange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; + outRange.pNext = VMA_NULL; + outRange.memory = allocation->GetMemory(); + + switch(allocation->GetType()) + { + case VmaAllocation_T::ALLOCATION_TYPE_DEDICATED: + outRange.offset = VmaAlignDown(offset, nonCoherentAtomSize); + if(size == VK_WHOLE_SIZE) + { + outRange.size = allocationSize - outRange.offset; + } + else + { + VMA_ASSERT(offset + size <= allocationSize); + outRange.size = VMA_MIN( + VmaAlignUp(size + (offset - outRange.offset), nonCoherentAtomSize), + allocationSize - outRange.offset); + } + break; + case VmaAllocation_T::ALLOCATION_TYPE_BLOCK: + { + // 1. Still within this allocation. + outRange.offset = VmaAlignDown(offset, nonCoherentAtomSize); + if(size == VK_WHOLE_SIZE) + { + size = allocationSize - offset; + } + else + { + VMA_ASSERT(offset + size <= allocationSize); + } + outRange.size = VmaAlignUp(size + (offset - outRange.offset), nonCoherentAtomSize); + + // 2. Adjust to whole block. + const VkDeviceSize allocationOffset = allocation->GetOffset(); + VMA_ASSERT(allocationOffset % nonCoherentAtomSize == 0); + const VkDeviceSize blockSize = allocation->GetBlock()->m_pMetadata->GetSize(); + outRange.offset += allocationOffset; + outRange.size = VMA_MIN(outRange.size, blockSize - outRange.offset); + + break; + } + default: + VMA_ASSERT(0); + } + return true; + } + return false; +} + +#if VMA_MEMORY_BUDGET + +void VmaAllocator_T::UpdateVulkanBudget() +{ + VMA_ASSERT(m_UseExtMemoryBudget); + + VkPhysicalDeviceMemoryProperties2KHR memProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PROPERTIES_2_KHR }; + + VkPhysicalDeviceMemoryBudgetPropertiesEXT budgetProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_BUDGET_PROPERTIES_EXT }; + VmaPnextChainPushFront(&memProps, &budgetProps); + + GetVulkanFunctions().vkGetPhysicalDeviceMemoryProperties2KHR(m_PhysicalDevice, &memProps); + + { + VmaMutexLockWrite lockWrite(m_Budget.m_BudgetMutex, m_UseMutex); + + for(uint32_t heapIndex = 0; heapIndex < GetMemoryHeapCount(); ++heapIndex) + { + m_Budget.m_VulkanUsage[heapIndex] = budgetProps.heapUsage[heapIndex]; + m_Budget.m_VulkanBudget[heapIndex] = budgetProps.heapBudget[heapIndex]; + m_Budget.m_BlockBytesAtBudgetFetch[heapIndex] = m_Budget.m_BlockBytes[heapIndex].load(); + + // Some bugged drivers return the budget incorrectly, e.g. 0 or much bigger than heap size. + if(m_Budget.m_VulkanBudget[heapIndex] == 0) + { + m_Budget.m_VulkanBudget[heapIndex] = m_MemProps.memoryHeaps[heapIndex].size * 8 / 10; // 80% heuristics. + } + else if(m_Budget.m_VulkanBudget[heapIndex] > m_MemProps.memoryHeaps[heapIndex].size) + { + m_Budget.m_VulkanBudget[heapIndex] = m_MemProps.memoryHeaps[heapIndex].size; + } + if(m_Budget.m_VulkanUsage[heapIndex] == 0 && m_Budget.m_BlockBytesAtBudgetFetch[heapIndex] > 0) + { + m_Budget.m_VulkanUsage[heapIndex] = m_Budget.m_BlockBytesAtBudgetFetch[heapIndex]; + } + } + m_Budget.m_OperationsSinceBudgetFetch = 0; + } +} + +#endif // #if VMA_MEMORY_BUDGET + +void VmaAllocator_T::FillAllocation(const VmaAllocation hAllocation, uint8_t pattern) +{ + if(VMA_DEBUG_INITIALIZE_ALLOCATIONS && + !hAllocation->CanBecomeLost() && + (m_MemProps.memoryTypes[hAllocation->GetMemoryTypeIndex()].propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) + { + void* pData = VMA_NULL; + VkResult res = Map(hAllocation, &pData); + if(res == VK_SUCCESS) + { + memset(pData, (int)pattern, (size_t)hAllocation->GetSize()); + FlushOrInvalidateAllocation(hAllocation, 0, VK_WHOLE_SIZE, VMA_CACHE_FLUSH); + Unmap(hAllocation); + } + else + { + VMA_ASSERT(0 && "VMA_DEBUG_INITIALIZE_ALLOCATIONS is enabled, but couldn't map memory to fill allocation."); + } + } +} + +uint32_t VmaAllocator_T::GetGpuDefragmentationMemoryTypeBits() +{ + uint32_t memoryTypeBits = m_GpuDefragmentationMemoryTypeBits.load(); + if(memoryTypeBits == UINT32_MAX) + { + memoryTypeBits = CalculateGpuDefragmentationMemoryTypeBits(); + m_GpuDefragmentationMemoryTypeBits.store(memoryTypeBits); + } + return memoryTypeBits; +} + +#if VMA_STATS_STRING_ENABLED + +void VmaAllocator_T::PrintDetailedMap(VmaJsonWriter& json) +{ + bool dedicatedAllocationsStarted = false; + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + VmaMutexLockRead dedicatedAllocationsLock(m_DedicatedAllocationsMutex[memTypeIndex], m_UseMutex); + AllocationVectorType* const pDedicatedAllocVector = m_pDedicatedAllocations[memTypeIndex]; + VMA_ASSERT(pDedicatedAllocVector); + if(pDedicatedAllocVector->empty() == false) + { + if(dedicatedAllocationsStarted == false) + { + dedicatedAllocationsStarted = true; + json.WriteString("DedicatedAllocations"); + json.BeginObject(); + } + + json.BeginString("Type "); + json.ContinueString(memTypeIndex); + json.EndString(); + + json.BeginArray(); + + for(size_t i = 0; i < pDedicatedAllocVector->size(); ++i) + { + json.BeginObject(true); + const VmaAllocation hAlloc = (*pDedicatedAllocVector)[i]; + hAlloc->PrintParameters(json); + json.EndObject(); + } + + json.EndArray(); + } + } + if(dedicatedAllocationsStarted) + { + json.EndObject(); + } + + { + bool allocationsStarted = false; + for(uint32_t memTypeIndex = 0; memTypeIndex < GetMemoryTypeCount(); ++memTypeIndex) + { + if(m_pBlockVectors[memTypeIndex]->IsEmpty() == false) + { + if(allocationsStarted == false) + { + allocationsStarted = true; + json.WriteString("DefaultPools"); + json.BeginObject(); + } + + json.BeginString("Type "); + json.ContinueString(memTypeIndex); + json.EndString(); + + m_pBlockVectors[memTypeIndex]->PrintDetailedMap(json); + } + } + if(allocationsStarted) + { + json.EndObject(); + } + } + + // Custom pools + { + VmaMutexLockRead lock(m_PoolsMutex, m_UseMutex); + const size_t poolCount = m_Pools.size(); + if(poolCount > 0) + { + json.WriteString("Pools"); + json.BeginObject(); + for(size_t poolIndex = 0; poolIndex < poolCount; ++poolIndex) + { + json.BeginString(); + json.ContinueString(m_Pools[poolIndex]->GetId()); + json.EndString(); + + m_Pools[poolIndex]->m_BlockVector.PrintDetailedMap(json); + } + json.EndObject(); + } + } +} + +#endif // #if VMA_STATS_STRING_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +// Public interface + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateAllocator( + const VmaAllocatorCreateInfo* pCreateInfo, + VmaAllocator* pAllocator) +{ + VMA_ASSERT(pCreateInfo && pAllocator); + VMA_ASSERT(pCreateInfo->vulkanApiVersion == 0 || + (VK_VERSION_MAJOR(pCreateInfo->vulkanApiVersion) == 1 && VK_VERSION_MINOR(pCreateInfo->vulkanApiVersion) <= 2)); + VMA_DEBUG_LOG("vmaCreateAllocator"); + *pAllocator = vma_new(pCreateInfo->pAllocationCallbacks, VmaAllocator_T)(pCreateInfo); + return (*pAllocator)->Init(pCreateInfo); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyAllocator( + VmaAllocator allocator) +{ + if(allocator != VK_NULL_HANDLE) + { + VMA_DEBUG_LOG("vmaDestroyAllocator"); + VkAllocationCallbacks allocationCallbacks = allocator->m_AllocationCallbacks; + vma_delete(&allocationCallbacks, allocator); + } +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocatorInfo(VmaAllocator allocator, VmaAllocatorInfo* pAllocatorInfo) +{ + VMA_ASSERT(allocator && pAllocatorInfo); + pAllocatorInfo->instance = allocator->m_hInstance; + pAllocatorInfo->physicalDevice = allocator->GetPhysicalDevice(); + pAllocatorInfo->device = allocator->m_hDevice; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetPhysicalDeviceProperties( + VmaAllocator allocator, + const VkPhysicalDeviceProperties **ppPhysicalDeviceProperties) +{ + VMA_ASSERT(allocator && ppPhysicalDeviceProperties); + *ppPhysicalDeviceProperties = &allocator->m_PhysicalDeviceProperties; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryProperties( + VmaAllocator allocator, + const VkPhysicalDeviceMemoryProperties** ppPhysicalDeviceMemoryProperties) +{ + VMA_ASSERT(allocator && ppPhysicalDeviceMemoryProperties); + *ppPhysicalDeviceMemoryProperties = &allocator->m_MemProps; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetMemoryTypeProperties( + VmaAllocator allocator, + uint32_t memoryTypeIndex, + VkMemoryPropertyFlags* pFlags) +{ + VMA_ASSERT(allocator && pFlags); + VMA_ASSERT(memoryTypeIndex < allocator->GetMemoryTypeCount()); + *pFlags = allocator->m_MemProps.memoryTypes[memoryTypeIndex].propertyFlags; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaSetCurrentFrameIndex( + VmaAllocator allocator, + uint32_t frameIndex) +{ + VMA_ASSERT(allocator); + VMA_ASSERT(frameIndex != VMA_FRAME_INDEX_LOST); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + allocator->SetCurrentFrameIndex(frameIndex); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaCalculateStats( + VmaAllocator allocator, + VmaStats* pStats) +{ + VMA_ASSERT(allocator && pStats); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + allocator->CalculateStats(pStats); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetBudget( + VmaAllocator allocator, + VmaBudget* pBudget) +{ + VMA_ASSERT(allocator && pBudget); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + allocator->GetBudget(pBudget, 0, allocator->GetMemoryHeapCount()); +} + +#if VMA_STATS_STRING_ENABLED + +VMA_CALL_PRE void VMA_CALL_POST vmaBuildStatsString( + VmaAllocator allocator, + char** ppStatsString, + VkBool32 detailedMap) +{ + VMA_ASSERT(allocator && ppStatsString); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VmaStringBuilder sb(allocator); + { + VmaJsonWriter json(allocator->GetAllocationCallbacks(), sb); + json.BeginObject(); + + VmaBudget budget[VK_MAX_MEMORY_HEAPS]; + allocator->GetBudget(budget, 0, allocator->GetMemoryHeapCount()); + + VmaStats stats; + allocator->CalculateStats(&stats); + + json.WriteString("Total"); + VmaPrintStatInfo(json, stats.total); + + for(uint32_t heapIndex = 0; heapIndex < allocator->GetMemoryHeapCount(); ++heapIndex) + { + json.BeginString("Heap "); + json.ContinueString(heapIndex); + json.EndString(); + json.BeginObject(); + + json.WriteString("Size"); + json.WriteNumber(allocator->m_MemProps.memoryHeaps[heapIndex].size); + + json.WriteString("Flags"); + json.BeginArray(true); + if((allocator->m_MemProps.memoryHeaps[heapIndex].flags & VK_MEMORY_HEAP_DEVICE_LOCAL_BIT) != 0) + { + json.WriteString("DEVICE_LOCAL"); + } + json.EndArray(); + + json.WriteString("Budget"); + json.BeginObject(); + { + json.WriteString("BlockBytes"); + json.WriteNumber(budget[heapIndex].blockBytes); + json.WriteString("AllocationBytes"); + json.WriteNumber(budget[heapIndex].allocationBytes); + json.WriteString("Usage"); + json.WriteNumber(budget[heapIndex].usage); + json.WriteString("Budget"); + json.WriteNumber(budget[heapIndex].budget); + } + json.EndObject(); + + if(stats.memoryHeap[heapIndex].blockCount > 0) + { + json.WriteString("Stats"); + VmaPrintStatInfo(json, stats.memoryHeap[heapIndex]); + } + + for(uint32_t typeIndex = 0; typeIndex < allocator->GetMemoryTypeCount(); ++typeIndex) + { + if(allocator->MemoryTypeIndexToHeapIndex(typeIndex) == heapIndex) + { + json.BeginString("Type "); + json.ContinueString(typeIndex); + json.EndString(); + + json.BeginObject(); + + json.WriteString("Flags"); + json.BeginArray(true); + VkMemoryPropertyFlags flags = allocator->m_MemProps.memoryTypes[typeIndex].propertyFlags; + if((flags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT) != 0) + { + json.WriteString("DEVICE_LOCAL"); + } + if((flags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) != 0) + { + json.WriteString("HOST_VISIBLE"); + } + if((flags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT) != 0) + { + json.WriteString("HOST_COHERENT"); + } + if((flags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT) != 0) + { + json.WriteString("HOST_CACHED"); + } + if((flags & VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT) != 0) + { + json.WriteString("LAZILY_ALLOCATED"); + } + if((flags & VK_MEMORY_PROPERTY_PROTECTED_BIT) != 0) + { + json.WriteString(" PROTECTED"); + } + if((flags & VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY) != 0) + { + json.WriteString(" DEVICE_COHERENT"); + } + if((flags & VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY) != 0) + { + json.WriteString(" DEVICE_UNCACHED"); + } + json.EndArray(); + + if(stats.memoryType[typeIndex].blockCount > 0) + { + json.WriteString("Stats"); + VmaPrintStatInfo(json, stats.memoryType[typeIndex]); + } + + json.EndObject(); + } + } + + json.EndObject(); + } + if(detailedMap == VK_TRUE) + { + allocator->PrintDetailedMap(json); + } + + json.EndObject(); + } + + const size_t len = sb.GetLength(); + char* const pChars = vma_new_array(allocator, char, len + 1); + if(len > 0) + { + memcpy(pChars, sb.GetData(), len); + } + pChars[len] = '\0'; + *ppStatsString = pChars; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaFreeStatsString( + VmaAllocator allocator, + char* pStatsString) +{ + if(pStatsString != VMA_NULL) + { + VMA_ASSERT(allocator); + size_t len = strlen(pStatsString); + vma_delete_array(allocator, pStatsString, len + 1); + } +} + +#endif // #if VMA_STATS_STRING_ENABLED + +/* +This function is not protected by any mutex because it just reads immutable data. +*/ +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndex( + VmaAllocator allocator, + uint32_t memoryTypeBits, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + uint32_t* pMemoryTypeIndex) +{ + VMA_ASSERT(allocator != VK_NULL_HANDLE); + VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); + VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); + + memoryTypeBits &= allocator->GetGlobalMemoryTypeBits(); + + if(pAllocationCreateInfo->memoryTypeBits != 0) + { + memoryTypeBits &= pAllocationCreateInfo->memoryTypeBits; + } + + uint32_t requiredFlags = pAllocationCreateInfo->requiredFlags; + uint32_t preferredFlags = pAllocationCreateInfo->preferredFlags; + uint32_t notPreferredFlags = 0; + + // Convert usage to requiredFlags and preferredFlags. + switch(pAllocationCreateInfo->usage) + { + case VMA_MEMORY_USAGE_UNKNOWN: + break; + case VMA_MEMORY_USAGE_GPU_ONLY: + if(!allocator->IsIntegratedGpu() || (preferredFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + { + preferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + break; + case VMA_MEMORY_USAGE_CPU_ONLY: + requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + break; + case VMA_MEMORY_USAGE_CPU_TO_GPU: + requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + if(!allocator->IsIntegratedGpu() || (preferredFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT) == 0) + { + preferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + } + break; + case VMA_MEMORY_USAGE_GPU_TO_CPU: + requiredFlags |= VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT; + preferredFlags |= VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + break; + case VMA_MEMORY_USAGE_CPU_COPY: + notPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT; + break; + case VMA_MEMORY_USAGE_GPU_LAZILY_ALLOCATED: + requiredFlags |= VK_MEMORY_PROPERTY_LAZILY_ALLOCATED_BIT; + break; + default: + VMA_ASSERT(0); + break; + } + + // Avoid DEVICE_COHERENT unless explicitly requested. + if(((pAllocationCreateInfo->requiredFlags | pAllocationCreateInfo->preferredFlags) & + (VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY | VK_MEMORY_PROPERTY_DEVICE_UNCACHED_BIT_AMD_COPY)) == 0) + { + notPreferredFlags |= VK_MEMORY_PROPERTY_DEVICE_COHERENT_BIT_AMD_COPY; + } + + *pMemoryTypeIndex = UINT32_MAX; + uint32_t minCost = UINT32_MAX; + for(uint32_t memTypeIndex = 0, memTypeBit = 1; + memTypeIndex < allocator->GetMemoryTypeCount(); + ++memTypeIndex, memTypeBit <<= 1) + { + // This memory type is acceptable according to memoryTypeBits bitmask. + if((memTypeBit & memoryTypeBits) != 0) + { + const VkMemoryPropertyFlags currFlags = + allocator->m_MemProps.memoryTypes[memTypeIndex].propertyFlags; + // This memory type contains requiredFlags. + if((requiredFlags & ~currFlags) == 0) + { + // Calculate cost as number of bits from preferredFlags not present in this memory type. + uint32_t currCost = VmaCountBitsSet(preferredFlags & ~currFlags) + + VmaCountBitsSet(currFlags & notPreferredFlags); + // Remember memory type with lowest cost. + if(currCost < minCost) + { + *pMemoryTypeIndex = memTypeIndex; + if(currCost == 0) + { + return VK_SUCCESS; + } + minCost = currCost; + } + } + } + } + return (*pMemoryTypeIndex != UINT32_MAX) ? VK_SUCCESS : VK_ERROR_FEATURE_NOT_PRESENT; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForBufferInfo( + VmaAllocator allocator, + const VkBufferCreateInfo* pBufferCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + uint32_t* pMemoryTypeIndex) +{ + VMA_ASSERT(allocator != VK_NULL_HANDLE); + VMA_ASSERT(pBufferCreateInfo != VMA_NULL); + VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); + VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); + + const VkDevice hDev = allocator->m_hDevice; + VkBuffer hBuffer = VK_NULL_HANDLE; + VkResult res = allocator->GetVulkanFunctions().vkCreateBuffer( + hDev, pBufferCreateInfo, allocator->GetAllocationCallbacks(), &hBuffer); + if(res == VK_SUCCESS) + { + VkMemoryRequirements memReq = {}; + allocator->GetVulkanFunctions().vkGetBufferMemoryRequirements( + hDev, hBuffer, &memReq); + + res = vmaFindMemoryTypeIndex( + allocator, + memReq.memoryTypeBits, + pAllocationCreateInfo, + pMemoryTypeIndex); + + allocator->GetVulkanFunctions().vkDestroyBuffer( + hDev, hBuffer, allocator->GetAllocationCallbacks()); + } + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFindMemoryTypeIndexForImageInfo( + VmaAllocator allocator, + const VkImageCreateInfo* pImageCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + uint32_t* pMemoryTypeIndex) +{ + VMA_ASSERT(allocator != VK_NULL_HANDLE); + VMA_ASSERT(pImageCreateInfo != VMA_NULL); + VMA_ASSERT(pAllocationCreateInfo != VMA_NULL); + VMA_ASSERT(pMemoryTypeIndex != VMA_NULL); + + const VkDevice hDev = allocator->m_hDevice; + VkImage hImage = VK_NULL_HANDLE; + VkResult res = allocator->GetVulkanFunctions().vkCreateImage( + hDev, pImageCreateInfo, allocator->GetAllocationCallbacks(), &hImage); + if(res == VK_SUCCESS) + { + VkMemoryRequirements memReq = {}; + allocator->GetVulkanFunctions().vkGetImageMemoryRequirements( + hDev, hImage, &memReq); + + res = vmaFindMemoryTypeIndex( + allocator, + memReq.memoryTypeBits, + pAllocationCreateInfo, + pMemoryTypeIndex); + + allocator->GetVulkanFunctions().vkDestroyImage( + hDev, hImage, allocator->GetAllocationCallbacks()); + } + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreatePool( + VmaAllocator allocator, + const VmaPoolCreateInfo* pCreateInfo, + VmaPool* pPool) +{ + VMA_ASSERT(allocator && pCreateInfo && pPool); + + VMA_DEBUG_LOG("vmaCreatePool"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkResult res = allocator->CreatePool(pCreateInfo, pPool); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordCreatePool(allocator->GetCurrentFrameIndex(), *pCreateInfo, *pPool); + } +#endif + + return res; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyPool( + VmaAllocator allocator, + VmaPool pool) +{ + VMA_ASSERT(allocator); + + if(pool == VK_NULL_HANDLE) + { + return; + } + + VMA_DEBUG_LOG("vmaDestroyPool"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordDestroyPool(allocator->GetCurrentFrameIndex(), pool); + } +#endif + + allocator->DestroyPool(pool); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolStats( + VmaAllocator allocator, + VmaPool pool, + VmaPoolStats* pPoolStats) +{ + VMA_ASSERT(allocator && pool && pPoolStats); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + allocator->GetPoolStats(pool, pPoolStats); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaMakePoolAllocationsLost( + VmaAllocator allocator, + VmaPool pool, + size_t* pLostAllocationCount) +{ + VMA_ASSERT(allocator && pool); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordMakePoolAllocationsLost(allocator->GetCurrentFrameIndex(), pool); + } +#endif + + allocator->MakePoolAllocationsLost(pool, pLostAllocationCount); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckPoolCorruption(VmaAllocator allocator, VmaPool pool) +{ + VMA_ASSERT(allocator && pool); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VMA_DEBUG_LOG("vmaCheckPoolCorruption"); + + return allocator->CheckPoolCorruption(pool); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetPoolName( + VmaAllocator allocator, + VmaPool pool, + const char** ppName) +{ + VMA_ASSERT(allocator && pool && ppName); + + VMA_DEBUG_LOG("vmaGetPoolName"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + *ppName = pool->GetName(); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaSetPoolName( + VmaAllocator allocator, + VmaPool pool, + const char* pName) +{ + VMA_ASSERT(allocator && pool); + + VMA_DEBUG_LOG("vmaSetPoolName"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + pool->SetName(pName); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordSetPoolName(allocator->GetCurrentFrameIndex(), pool, pName); + } +#endif +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemory( + VmaAllocator allocator, + const VkMemoryRequirements* pVkMemoryRequirements, + const VmaAllocationCreateInfo* pCreateInfo, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && pVkMemoryRequirements && pCreateInfo && pAllocation); + + VMA_DEBUG_LOG("vmaAllocateMemory"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkResult result = allocator->AllocateMemory( + *pVkMemoryRequirements, + false, // requiresDedicatedAllocation + false, // prefersDedicatedAllocation + VK_NULL_HANDLE, // dedicatedBuffer + UINT32_MAX, // dedicatedBufferUsage + VK_NULL_HANDLE, // dedicatedImage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_UNKNOWN, + 1, // allocationCount + pAllocation); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordAllocateMemory( + allocator->GetCurrentFrameIndex(), + *pVkMemoryRequirements, + *pCreateInfo, + *pAllocation); + } +#endif + + if(pAllocationInfo != VMA_NULL && result == VK_SUCCESS) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } + + return result; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryPages( + VmaAllocator allocator, + const VkMemoryRequirements* pVkMemoryRequirements, + const VmaAllocationCreateInfo* pCreateInfo, + size_t allocationCount, + VmaAllocation* pAllocations, + VmaAllocationInfo* pAllocationInfo) +{ + if(allocationCount == 0) + { + return VK_SUCCESS; + } + + VMA_ASSERT(allocator && pVkMemoryRequirements && pCreateInfo && pAllocations); + + VMA_DEBUG_LOG("vmaAllocateMemoryPages"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkResult result = allocator->AllocateMemory( + *pVkMemoryRequirements, + false, // requiresDedicatedAllocation + false, // prefersDedicatedAllocation + VK_NULL_HANDLE, // dedicatedBuffer + UINT32_MAX, // dedicatedBufferUsage + VK_NULL_HANDLE, // dedicatedImage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_UNKNOWN, + allocationCount, + pAllocations); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordAllocateMemoryPages( + allocator->GetCurrentFrameIndex(), + *pVkMemoryRequirements, + *pCreateInfo, + (uint64_t)allocationCount, + pAllocations); + } +#endif + + if(pAllocationInfo != VMA_NULL && result == VK_SUCCESS) + { + for(size_t i = 0; i < allocationCount; ++i) + { + allocator->GetAllocationInfo(pAllocations[i], pAllocationInfo + i); + } + } + + return result; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForBuffer( + VmaAllocator allocator, + VkBuffer buffer, + const VmaAllocationCreateInfo* pCreateInfo, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && buffer != VK_NULL_HANDLE && pCreateInfo && pAllocation); + + VMA_DEBUG_LOG("vmaAllocateMemoryForBuffer"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetBufferMemoryRequirements(buffer, vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation); + + VkResult result = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + buffer, // dedicatedBuffer + UINT32_MAX, // dedicatedBufferUsage + VK_NULL_HANDLE, // dedicatedImage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_BUFFER, + 1, // allocationCount + pAllocation); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordAllocateMemoryForBuffer( + allocator->GetCurrentFrameIndex(), + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + *pCreateInfo, + *pAllocation); + } +#endif + + if(pAllocationInfo && result == VK_SUCCESS) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } + + return result; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaAllocateMemoryForImage( + VmaAllocator allocator, + VkImage image, + const VmaAllocationCreateInfo* pCreateInfo, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && image != VK_NULL_HANDLE && pCreateInfo && pAllocation); + + VMA_DEBUG_LOG("vmaAllocateMemoryForImage"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetImageMemoryRequirements(image, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); + + VkResult result = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + VK_NULL_HANDLE, // dedicatedBuffer + UINT32_MAX, // dedicatedBufferUsage + image, // dedicatedImage + *pCreateInfo, + VMA_SUBALLOCATION_TYPE_IMAGE_UNKNOWN, + 1, // allocationCount + pAllocation); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordAllocateMemoryForImage( + allocator->GetCurrentFrameIndex(), + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + *pCreateInfo, + *pAllocation); + } +#endif + + if(pAllocationInfo && result == VK_SUCCESS) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } + + return result; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemory( + VmaAllocator allocator, + VmaAllocation allocation) +{ + VMA_ASSERT(allocator); + + if(allocation == VK_NULL_HANDLE) + { + return; + } + + VMA_DEBUG_LOG("vmaFreeMemory"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordFreeMemory( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + allocator->FreeMemory( + 1, // allocationCount + &allocation); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaFreeMemoryPages( + VmaAllocator allocator, + size_t allocationCount, + const VmaAllocation* pAllocations) +{ + if(allocationCount == 0) + { + return; + } + + VMA_ASSERT(allocator); + + VMA_DEBUG_LOG("vmaFreeMemoryPages"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordFreeMemoryPages( + allocator->GetCurrentFrameIndex(), + (uint64_t)allocationCount, + pAllocations); + } +#endif + + allocator->FreeMemory(allocationCount, pAllocations); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaResizeAllocation( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize newSize) +{ + VMA_ASSERT(allocator && allocation); + + VMA_DEBUG_LOG("vmaResizeAllocation"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->ResizeAllocation(allocation, newSize); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaGetAllocationInfo( + VmaAllocator allocator, + VmaAllocation allocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && allocation && pAllocationInfo); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordGetAllocationInfo( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + allocator->GetAllocationInfo(allocation, pAllocationInfo); +} + +VMA_CALL_PRE VkBool32 VMA_CALL_POST vmaTouchAllocation( + VmaAllocator allocator, + VmaAllocation allocation) +{ + VMA_ASSERT(allocator && allocation); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordTouchAllocation( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + return allocator->TouchAllocation(allocation); +} + +VMA_CALL_PRE void VMA_CALL_POST vmaSetAllocationUserData( + VmaAllocator allocator, + VmaAllocation allocation, + void* pUserData) +{ + VMA_ASSERT(allocator && allocation); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + allocation->SetUserData(allocator, pUserData); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordSetAllocationUserData( + allocator->GetCurrentFrameIndex(), + allocation, + pUserData); + } +#endif +} + +VMA_CALL_PRE void VMA_CALL_POST vmaCreateLostAllocation( + VmaAllocator allocator, + VmaAllocation* pAllocation) +{ + VMA_ASSERT(allocator && pAllocation); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK; + + allocator->CreateLostAllocation(pAllocation); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordCreateLostAllocation( + allocator->GetCurrentFrameIndex(), + *pAllocation); + } +#endif +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaMapMemory( + VmaAllocator allocator, + VmaAllocation allocation, + void** ppData) +{ + VMA_ASSERT(allocator && allocation && ppData); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkResult res = allocator->Map(allocation, ppData); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordMapMemory( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + return res; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaUnmapMemory( + VmaAllocator allocator, + VmaAllocation allocation) +{ + VMA_ASSERT(allocator && allocation); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordUnmapMemory( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + allocator->Unmap(allocation); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocation(VmaAllocator allocator, VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) +{ + VMA_ASSERT(allocator && allocation); + + VMA_DEBUG_LOG("vmaFlushAllocation"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + const VkResult res = allocator->FlushOrInvalidateAllocation(allocation, offset, size, VMA_CACHE_FLUSH); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordFlushAllocation( + allocator->GetCurrentFrameIndex(), + allocation, offset, size); + } +#endif + + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocation(VmaAllocator allocator, VmaAllocation allocation, VkDeviceSize offset, VkDeviceSize size) +{ + VMA_ASSERT(allocator && allocation); + + VMA_DEBUG_LOG("vmaInvalidateAllocation"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + const VkResult res = allocator->FlushOrInvalidateAllocation(allocation, offset, size, VMA_CACHE_INVALIDATE); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordInvalidateAllocation( + allocator->GetCurrentFrameIndex(), + allocation, offset, size); + } +#endif + + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaFlushAllocations( + VmaAllocator allocator, + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, + const VkDeviceSize* sizes) +{ + VMA_ASSERT(allocator); + + if(allocationCount == 0) + { + return VK_SUCCESS; + } + + VMA_ASSERT(allocations); + + VMA_DEBUG_LOG("vmaFlushAllocations"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + const VkResult res = allocator->FlushOrInvalidateAllocations(allocationCount, allocations, offsets, sizes, VMA_CACHE_FLUSH); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + //TODO + } +#endif + + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaInvalidateAllocations( + VmaAllocator allocator, + uint32_t allocationCount, + const VmaAllocation* allocations, + const VkDeviceSize* offsets, + const VkDeviceSize* sizes) +{ + VMA_ASSERT(allocator); + + if(allocationCount == 0) + { + return VK_SUCCESS; + } + + VMA_ASSERT(allocations); + + VMA_DEBUG_LOG("vmaInvalidateAllocations"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + const VkResult res = allocator->FlushOrInvalidateAllocations(allocationCount, allocations, offsets, sizes, VMA_CACHE_INVALIDATE); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + //TODO + } +#endif + + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCheckCorruption(VmaAllocator allocator, uint32_t memoryTypeBits) +{ + VMA_ASSERT(allocator); + + VMA_DEBUG_LOG("vmaCheckCorruption"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->CheckCorruption(memoryTypeBits); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragment( + VmaAllocator allocator, + const VmaAllocation* pAllocations, + size_t allocationCount, + VkBool32* pAllocationsChanged, + const VmaDefragmentationInfo *pDefragmentationInfo, + VmaDefragmentationStats* pDefragmentationStats) +{ + // Deprecated interface, reimplemented using new one. + + VmaDefragmentationInfo2 info2 = {}; + info2.allocationCount = (uint32_t)allocationCount; + info2.pAllocations = pAllocations; + info2.pAllocationsChanged = pAllocationsChanged; + if(pDefragmentationInfo != VMA_NULL) + { + info2.maxCpuAllocationsToMove = pDefragmentationInfo->maxAllocationsToMove; + info2.maxCpuBytesToMove = pDefragmentationInfo->maxBytesToMove; + } + else + { + info2.maxCpuAllocationsToMove = UINT32_MAX; + info2.maxCpuBytesToMove = VK_WHOLE_SIZE; + } + // info2.flags, maxGpuAllocationsToMove, maxGpuBytesToMove, commandBuffer deliberately left zero. + + VmaDefragmentationContext ctx; + VkResult res = vmaDefragmentationBegin(allocator, &info2, pDefragmentationStats, &ctx); + if(res == VK_NOT_READY) + { + res = vmaDefragmentationEnd( allocator, ctx); + } + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationBegin( + VmaAllocator allocator, + const VmaDefragmentationInfo2* pInfo, + VmaDefragmentationStats* pStats, + VmaDefragmentationContext *pContext) +{ + VMA_ASSERT(allocator && pInfo && pContext); + + // Degenerate case: Nothing to defragment. + if(pInfo->allocationCount == 0 && pInfo->poolCount == 0) + { + return VK_SUCCESS; + } + + VMA_ASSERT(pInfo->allocationCount == 0 || pInfo->pAllocations != VMA_NULL); + VMA_ASSERT(pInfo->poolCount == 0 || pInfo->pPools != VMA_NULL); + VMA_HEAVY_ASSERT(VmaValidatePointerArray(pInfo->allocationCount, pInfo->pAllocations)); + VMA_HEAVY_ASSERT(VmaValidatePointerArray(pInfo->poolCount, pInfo->pPools)); + + VMA_DEBUG_LOG("vmaDefragmentationBegin"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + VkResult res = allocator->DefragmentationBegin(*pInfo, pStats, pContext); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordDefragmentationBegin( + allocator->GetCurrentFrameIndex(), *pInfo, *pContext); + } +#endif + + return res; +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaDefragmentationEnd( + VmaAllocator allocator, + VmaDefragmentationContext context) +{ + VMA_ASSERT(allocator); + + VMA_DEBUG_LOG("vmaDefragmentationEnd"); + + if(context != VK_NULL_HANDLE) + { + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordDefragmentationEnd( + allocator->GetCurrentFrameIndex(), context); + } +#endif + + return allocator->DefragmentationEnd(context); + } + else + { + return VK_SUCCESS; + } +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBeginDefragmentationPass( + VmaAllocator allocator, + VmaDefragmentationContext context, + VmaDefragmentationPassInfo* pInfo + ) +{ + VMA_ASSERT(allocator); + VMA_ASSERT(pInfo); + + VMA_DEBUG_LOG("vmaBeginDefragmentationPass"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + if(context == VK_NULL_HANDLE) + { + pInfo->moveCount = 0; + return VK_SUCCESS; + } + + return allocator->DefragmentationPassBegin(pInfo, context); +} +VMA_CALL_PRE VkResult VMA_CALL_POST vmaEndDefragmentationPass( + VmaAllocator allocator, + VmaDefragmentationContext context) +{ + VMA_ASSERT(allocator); + + VMA_DEBUG_LOG("vmaEndDefragmentationPass"); + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + if(context == VK_NULL_HANDLE) + return VK_SUCCESS; + + return allocator->DefragmentationPassEnd(context); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory( + VmaAllocator allocator, + VmaAllocation allocation, + VkBuffer buffer) +{ + VMA_ASSERT(allocator && allocation && buffer); + + VMA_DEBUG_LOG("vmaBindBufferMemory"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->BindBufferMemory(allocation, 0, buffer, VMA_NULL); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindBufferMemory2( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize allocationLocalOffset, + VkBuffer buffer, + const void* pNext) +{ + VMA_ASSERT(allocator && allocation && buffer); + + VMA_DEBUG_LOG("vmaBindBufferMemory2"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->BindBufferMemory(allocation, allocationLocalOffset, buffer, pNext); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory( + VmaAllocator allocator, + VmaAllocation allocation, + VkImage image) +{ + VMA_ASSERT(allocator && allocation && image); + + VMA_DEBUG_LOG("vmaBindImageMemory"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->BindImageMemory(allocation, 0, image, VMA_NULL); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaBindImageMemory2( + VmaAllocator allocator, + VmaAllocation allocation, + VkDeviceSize allocationLocalOffset, + VkImage image, + const void* pNext) +{ + VMA_ASSERT(allocator && allocation && image); + + VMA_DEBUG_LOG("vmaBindImageMemory2"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + return allocator->BindImageMemory(allocation, allocationLocalOffset, image, pNext); +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateBuffer( + VmaAllocator allocator, + const VkBufferCreateInfo* pBufferCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkBuffer* pBuffer, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && pBufferCreateInfo && pAllocationCreateInfo && pBuffer && pAllocation); + + if(pBufferCreateInfo->size == 0) + { + return VK_ERROR_VALIDATION_FAILED_EXT; + } + if((pBufferCreateInfo->usage & VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT_COPY) != 0 && + !allocator->m_UseKhrBufferDeviceAddress) + { + VMA_ASSERT(0 && "Creating a buffer with VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT is not valid if VMA_ALLOCATOR_CREATE_BUFFER_DEVICE_ADDRESS_BIT was not used."); + return VK_ERROR_VALIDATION_FAILED_EXT; + } + + VMA_DEBUG_LOG("vmaCreateBuffer"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + *pBuffer = VK_NULL_HANDLE; + *pAllocation = VK_NULL_HANDLE; + + // 1. Create VkBuffer. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateBuffer)( + allocator->m_hDevice, + pBufferCreateInfo, + allocator->GetAllocationCallbacks(), + pBuffer); + if(res >= 0) + { + // 2. vkGetBufferMemoryRequirements. + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetBufferMemoryRequirements(*pBuffer, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); + + // 3. Allocate memory using allocator. + res = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + *pBuffer, // dedicatedBuffer + pBufferCreateInfo->usage, // dedicatedBufferUsage + VK_NULL_HANDLE, // dedicatedImage + *pAllocationCreateInfo, + VMA_SUBALLOCATION_TYPE_BUFFER, + 1, // allocationCount + pAllocation); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordCreateBuffer( + allocator->GetCurrentFrameIndex(), + *pBufferCreateInfo, + *pAllocationCreateInfo, + *pAllocation); + } +#endif + + if(res >= 0) + { + // 3. Bind buffer with memory. + if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) + { + res = allocator->BindBufferMemory(*pAllocation, 0, *pBuffer, VMA_NULL); + } + if(res >= 0) + { + // All steps succeeded. + #if VMA_STATS_STRING_ENABLED + (*pAllocation)->InitBufferImageUsage(pBufferCreateInfo->usage); + #endif + if(pAllocationInfo != VMA_NULL) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } + + return VK_SUCCESS; + } + allocator->FreeMemory( + 1, // allocationCount + pAllocation); + *pAllocation = VK_NULL_HANDLE; + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); + *pBuffer = VK_NULL_HANDLE; + return res; + } + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, *pBuffer, allocator->GetAllocationCallbacks()); + *pBuffer = VK_NULL_HANDLE; + return res; + } + return res; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyBuffer( + VmaAllocator allocator, + VkBuffer buffer, + VmaAllocation allocation) +{ + VMA_ASSERT(allocator); + + if(buffer == VK_NULL_HANDLE && allocation == VK_NULL_HANDLE) + { + return; + } + + VMA_DEBUG_LOG("vmaDestroyBuffer"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordDestroyBuffer( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + if(buffer != VK_NULL_HANDLE) + { + (*allocator->GetVulkanFunctions().vkDestroyBuffer)(allocator->m_hDevice, buffer, allocator->GetAllocationCallbacks()); + } + + if(allocation != VK_NULL_HANDLE) + { + allocator->FreeMemory( + 1, // allocationCount + &allocation); + } +} + +VMA_CALL_PRE VkResult VMA_CALL_POST vmaCreateImage( + VmaAllocator allocator, + const VkImageCreateInfo* pImageCreateInfo, + const VmaAllocationCreateInfo* pAllocationCreateInfo, + VkImage* pImage, + VmaAllocation* pAllocation, + VmaAllocationInfo* pAllocationInfo) +{ + VMA_ASSERT(allocator && pImageCreateInfo && pAllocationCreateInfo && pImage && pAllocation); + + if(pImageCreateInfo->extent.width == 0 || + pImageCreateInfo->extent.height == 0 || + pImageCreateInfo->extent.depth == 0 || + pImageCreateInfo->mipLevels == 0 || + pImageCreateInfo->arrayLayers == 0) + { + return VK_ERROR_VALIDATION_FAILED_EXT; + } + + VMA_DEBUG_LOG("vmaCreateImage"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + + *pImage = VK_NULL_HANDLE; + *pAllocation = VK_NULL_HANDLE; + + // 1. Create VkImage. + VkResult res = (*allocator->GetVulkanFunctions().vkCreateImage)( + allocator->m_hDevice, + pImageCreateInfo, + allocator->GetAllocationCallbacks(), + pImage); + if(res >= 0) + { + VmaSuballocationType suballocType = pImageCreateInfo->tiling == VK_IMAGE_TILING_OPTIMAL ? + VMA_SUBALLOCATION_TYPE_IMAGE_OPTIMAL : + VMA_SUBALLOCATION_TYPE_IMAGE_LINEAR; + + // 2. Allocate memory using allocator. + VkMemoryRequirements vkMemReq = {}; + bool requiresDedicatedAllocation = false; + bool prefersDedicatedAllocation = false; + allocator->GetImageMemoryRequirements(*pImage, vkMemReq, + requiresDedicatedAllocation, prefersDedicatedAllocation); + + res = allocator->AllocateMemory( + vkMemReq, + requiresDedicatedAllocation, + prefersDedicatedAllocation, + VK_NULL_HANDLE, // dedicatedBuffer + UINT32_MAX, // dedicatedBufferUsage + *pImage, // dedicatedImage + *pAllocationCreateInfo, + suballocType, + 1, // allocationCount + pAllocation); + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordCreateImage( + allocator->GetCurrentFrameIndex(), + *pImageCreateInfo, + *pAllocationCreateInfo, + *pAllocation); + } +#endif + + if(res >= 0) + { + // 3. Bind image with memory. + if((pAllocationCreateInfo->flags & VMA_ALLOCATION_CREATE_DONT_BIND_BIT) == 0) + { + res = allocator->BindImageMemory(*pAllocation, 0, *pImage, VMA_NULL); + } + if(res >= 0) + { + // All steps succeeded. + #if VMA_STATS_STRING_ENABLED + (*pAllocation)->InitBufferImageUsage(pImageCreateInfo->usage); + #endif + if(pAllocationInfo != VMA_NULL) + { + allocator->GetAllocationInfo(*pAllocation, pAllocationInfo); + } + + return VK_SUCCESS; + } + allocator->FreeMemory( + 1, // allocationCount + pAllocation); + *pAllocation = VK_NULL_HANDLE; + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); + *pImage = VK_NULL_HANDLE; + return res; + } + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, *pImage, allocator->GetAllocationCallbacks()); + *pImage = VK_NULL_HANDLE; + return res; + } + return res; +} + +VMA_CALL_PRE void VMA_CALL_POST vmaDestroyImage( + VmaAllocator allocator, + VkImage image, + VmaAllocation allocation) +{ + VMA_ASSERT(allocator); + + if(image == VK_NULL_HANDLE && allocation == VK_NULL_HANDLE) + { + return; + } + + VMA_DEBUG_LOG("vmaDestroyImage"); + + VMA_DEBUG_GLOBAL_MUTEX_LOCK + +#if VMA_RECORDING_ENABLED + if(allocator->GetRecorder() != VMA_NULL) + { + allocator->GetRecorder()->RecordDestroyImage( + allocator->GetCurrentFrameIndex(), + allocation); + } +#endif + + if(image != VK_NULL_HANDLE) + { + (*allocator->GetVulkanFunctions().vkDestroyImage)(allocator->m_hDevice, image, allocator->GetAllocationCallbacks()); + } + if(allocation != VK_NULL_HANDLE) + { + allocator->FreeMemory( + 1, // allocationCount + &allocation); + } +} + +#endif // #ifdef VMA_IMPLEMENTATION diff --git a/aten/src/ATen/native/xnnpack/Common.h b/aten/src/ATen/native/xnnpack/Common.h index 90984990c2ef7a6..789f69e607ab6d7 100644 --- a/aten/src/ATen/native/xnnpack/Common.h +++ b/aten/src/ATen/native/xnnpack/Common.h @@ -33,15 +33,19 @@ struct ContextLinear final { static constexpr float kMax = std::numeric_limits::infinity(); }; +// This contains information for both the transpose and non-transpose cases. struct ContextConv2D final { Operator op; std::array weight_size_; std::array padding_; + std::array output_padding_; std::array stride_; std::array dilation_; const float* cached_input_ptr{nullptr}; const float* cached_output_ptr{nullptr}; size_t input_height{0}, input_width{0}, batch_size{0}, input_channels{0}; + bool transposed_; + int64_t groups_; ContextConv2D() = delete; @@ -49,13 +53,19 @@ struct ContextConv2D final { Operator&& o, std::array weight_size, std::array padding, + std::array output_padding, std::array stride, - std::array dilation) + std::array dilation, + bool transposed, + int64_t groups) : op(std::move(o)), weight_size_(weight_size), padding_(padding), + output_padding_(output_padding), stride_(stride), - dilation_(dilation) {} + dilation_(dilation), + transposed_(transposed), + groups_(groups) {} static constexpr float kMin = -std::numeric_limits::infinity(); static constexpr float kMax = std::numeric_limits::infinity(); }; diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp index 779f6e100143cad..59f129e3ee8947e 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.cpp +++ b/aten/src/ATen/native/xnnpack/Convolution.cpp @@ -31,6 +31,7 @@ bool available( const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups, + const bool transposed, const float output_min, const float output_max) { // XNNPACK @@ -43,9 +44,10 @@ bool available( (kFloat == weight.scalar_type()) && // Bias ((bias && bias->defined()) ? ((1 == bias->ndimension()) && - (c10::DeviceType::CPU == bias->device().type()) && - (kFloat == bias->scalar_type()) && - (weight.size(Layout::Filter::output)) == bias->size(0)) + (c10::DeviceType::CPU == bias->device().type()) && + (kFloat == bias->scalar_type()) && + ((transposed ? (weight.size(Layout::Filter::input) == (bias->size(0) / groups)) + : (weight.size(Layout::Filter::output) == (bias->size(0)))))) : true) && // Padding (padding[Layout::Parameter::height] >= 0) && @@ -88,35 +90,97 @@ Tensor create_and_run( const Tensor& weight, const Tensor& bias, const IntArrayRef padding, + const IntArrayRef output_padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups, + const bool transposed, const float output_min, const float output_max) { auto op_context = create( weight, bias, padding, + output_padding, stride, dilation, groups, + transposed, output_min, output_max); return run(op_context, input); } +// XNNPack's deconvolution operator expects weights to be indexed in the following order: +// * Groups +// * Group Output Channels +// * Kernel Height +// * Kernel Width +// * Group Input Channels +// +// (ref: https://github.com/google/XNNPACK/blob/ecd8311c8fd3d9ab47edbc3df5f2b5de7dabe75f/test/deconvolution-operator-tester.h#L678) +// +// This function takes in a contiguous NHWC pytorch tensor (e.g. MemoryFormat == ChannelsLast) and rearranges the weights in preparation for use with xnnpack. +// By default, for pytorch, transpose conv2d weights are {input_channels, output_Channels_per_group, kernel_height, kernel_width}. +// In addition, it condenses the tensor from 5 to 4 dimensions as expected by the rest of the pytorch framework by combining the groups and input_channels dimension. +const Tensor reorder_weights_for_transpose_conv(const Tensor& weight_nhwc, + int num_groups) { + + TORCH_CHECK(weight_nhwc.size(0) % num_groups == 0, "The number of groups cannot be satisfied by the provided weight tensor."); + + int input_channels_per_group = weight_nhwc.size(0) / num_groups; + int output_channels_per_group = weight_nhwc.size(1); + int kernel_width = weight_nhwc.size(3); + int kernel_height = weight_nhwc.size(2); + + int o_offset = 1; + int h_offset = (output_channels_per_group); + int w_offset = (output_channels_per_group)*(kernel_height); + int i_offset = (output_channels_per_group)*(kernel_height)*(kernel_width); + int g_offset = (output_channels_per_group)*(kernel_height)*(kernel_width)*(input_channels_per_group); + + Tensor reordered = mobile::empty_with_tail_padding( + weight_nhwc.sizes(), + weight_nhwc.options().dtype(), + MemoryFormat::ChannelsLast, + weight_nhwc.names()); + + float* out_ptr = reordered.data_ptr(); + float* in_ptr = weight_nhwc.data_ptr(); + + int out_index = 0; + for (int g = 0; g < num_groups; g++) { + for (int o = 0; o < output_channels_per_group; o++) { + for (int w = 0; w < kernel_width; w++) { + for (int h = 0; h < kernel_height; h++) { + for (int i = 0; i < input_channels_per_group; i++) { + int in_index = (g*g_offset) + (i*i_offset) + (h*h_offset) + (w*w_offset) + (o*o_offset); + out_ptr[out_index] = in_ptr[in_index]; + out_index++; + } + } + } + } + } + + return reordered; +} + } // namespace ContextConv2D create( const Tensor& weight, const c10::optional& bias, const IntArrayRef padding, + const IntArrayRef output_padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups, + const bool transposed, const float output_min, const float output_max) { const auto padding_expanded = expand_param_if_needed(padding, "padding", 2); + const auto output_padding_expanded = expand_param_if_needed(output_padding, "output_padding", 2); const auto stride_expanded = expand_param_if_needed(stride, "stride", 2); const auto dilation_expanded = expand_param_if_needed(dilation, "dilation", 2); const Tensor weight_nhwc = weight.contiguous(MemoryFormat::ChannelsLast); @@ -129,15 +193,52 @@ ContextConv2D create( stride_expanded, dilation_expanded, groups, + transposed, output_min, output_max), "xnnpack::convolution not available! " - "Reason: The provided (weight, bias, padding, stride, dilation, groups, output_min, output_max) " + "Reason: The provided (weight, bias, padding, stride, dilation, groups, transposed, output_min, output_max) " "parameters are either invalid individually or their combination is not supported by XNNPACK."); - xnn_operator_t convolution_op{}; - const xnn_status create_status = xnn_create_convolution2d_nhwc_f32( + xnn_operator_t convolution_op{}; + xnn_status create_status; + std::array weight_sizes; + + if (transposed) { + const Tensor weight_reordered = reorder_weights_for_transpose_conv(weight_nhwc, groups); + for (int i = 0; i < 4; i++) { + weight_sizes[i] = weight_reordered.size(i); + } + create_status = xnn_create_deconvolution2d_nhwc_f32( + padding_expanded[Layout::Parameter::height], // output_padding_top + padding_expanded[Layout::Parameter::width], // output_padding_right + padding_expanded[Layout::Parameter::height], // output_padding_bottom + padding_expanded[Layout::Parameter::width], // output_padding_left + weight_reordered.size(Layout::Filter::height), // kernel_height + weight_reordered.size(Layout::Filter::width), // kernel_width + stride_expanded[Layout::Parameter::height], // subsampling_height + stride_expanded[Layout::Parameter::width], // subsampling_width + dilation_expanded[Layout::Parameter::height], // dilation_height + dilation_expanded[Layout::Parameter::width], // dilation_width + groups, // groups + weight_reordered.size(Layout::Filter::output) / groups, // group_input_channels + weight_reordered.size(Layout::Filter::input), // group_output_channels + weight_reordered.size(Layout::Filter::output), // input_pixel_stride + weight_reordered.size(Layout::Filter::input) * groups, // output_pixel_stride + weight_reordered.data_ptr(), // kernel + (bias && bias->defined()) + ? bias->contiguous().data_ptr() + : nullptr, // bias + output_min, // output_min + output_max, // output_max + 0u, // flags + &convolution_op); // operator + } else { + for (int i = 0; i < 4; i++) { + weight_sizes[i] = weight_nhwc.size(i); + } + create_status = xnn_create_convolution2d_nhwc_f32( padding_expanded[Layout::Parameter::height], // input_padding_top padding_expanded[Layout::Parameter::width], // input_padding_right padding_expanded[Layout::Parameter::height], // input_padding_bottom @@ -161,18 +262,21 @@ ContextConv2D create( output_max, // output_max 0u, // flags &convolution_op); // operator + } TORCH_CHECK( xnn_status_success == create_status, - "xnn_create_convolution2d_nhwc_f32 failed!"); + (transposed ? "xnn_create_deconvolution2d_nhwc_f32 failed!" + : "xnn_create_convolution2d_nhwc_f32 failed!")); return ContextConv2D{ Operator(convolution_op), - {weight_nhwc.sizes()[0], weight_nhwc.sizes()[1], - weight_nhwc.sizes()[2], weight_nhwc.sizes()[3]}, + weight_sizes, {padding_expanded[0], padding_expanded[1]}, + {output_padding_expanded[0], output_padding_expanded[1]}, {stride_expanded[0], stride_expanded[1]}, - {dilation_expanded[0], dilation_expanded[1]} + {dilation_expanded[0], dilation_expanded[1]}, + transposed, groups }; } @@ -189,7 +293,21 @@ Tensor run( "XNNPACK Convolution not usable! " "Reason: The provided input tensor is either invalid or unsupported by XNNPACK."); - Tensor output = mobile::empty_with_tail_padding( + Tensor output; + if (context.transposed_) { + output = mobile::empty_with_tail_padding( + conv_input_size(padded_input_nhwc.sizes(), + context.weight_size_, + context.padding_, + context.output_padding_, + context.stride_, + context.dilation_, + context.groups_), + padded_input_nhwc.options().dtype(), + MemoryFormat::ChannelsLast, + padded_input_nhwc.names()); + } else { + output = mobile::empty_with_tail_padding( conv_output_size( padded_input_nhwc.sizes(), context.weight_size_, @@ -199,7 +317,9 @@ Tensor run( padded_input_nhwc.options().dtype(), MemoryFormat::ChannelsLast, padded_input_nhwc.names()); + } + xnn_status setup_status; if ((context.cached_input_ptr != padded_input_nhwc.data_ptr()) || (context.cached_output_ptr != output.data_ptr()) || (padded_input_nhwc.size(Layout::Activation4D::batch) != @@ -211,26 +331,42 @@ Tensor run( (padded_input_nhwc.size(Layout::Activation4D::width) != context.input_width) ) { - const xnn_status setup_status = xnn_setup_convolution2d_nhwc_f32( - context.op.get(), // operator - padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size - padded_input_nhwc.size(Layout::Activation4D::height), // input_height - padded_input_nhwc.size(Layout::Activation4D::width), // input_width - padded_input_nhwc.data_ptr(), // input - output.data_ptr(), // output - caffe2::pthreadpool_()); // threadpool - - TORCH_CHECK( - xnn_status_success == setup_status, - "xnn_setup_convolution2d_nhwc_f32 failed!"); - - // Cache values to avoid setup for the next round. - context.cached_input_ptr = padded_input_nhwc.data_ptr(); - context.cached_output_ptr = output.data_ptr(); - context.batch_size = padded_input_nhwc.size(Layout::Activation4D::batch); - context.input_channels = padded_input_nhwc.size(Layout::Activation4D::channels); - context.input_height = padded_input_nhwc.size(Layout::Activation4D::height); - context.input_width = padded_input_nhwc.size(Layout::Activation4D::width); + + if (context.transposed_) { + setup_status = xnn_setup_deconvolution2d_nhwc_f32( + context.op.get(), // operator + padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size + padded_input_nhwc.size(Layout::Activation4D::height), // input_height + padded_input_nhwc.size(Layout::Activation4D::width), // input_width + context.output_padding_[0], // adjustment_height + context.output_padding_[1], // adjustment_width + padded_input_nhwc.data_ptr(), // input + output.data_ptr(), // output + caffe2::pthreadpool_()); // threadpool + + } else { + setup_status = xnn_setup_convolution2d_nhwc_f32( + context.op.get(), // operator + padded_input_nhwc.size(Layout::Activation4D::batch), // batch_size + padded_input_nhwc.size(Layout::Activation4D::height), // input_height + padded_input_nhwc.size(Layout::Activation4D::width), // input_width + padded_input_nhwc.data_ptr(), // input + output.data_ptr(), // output + caffe2::pthreadpool_()); + } + + TORCH_CHECK( + xnn_status_success == setup_status, + (context.transposed_ ? "xnn_setup_deconvolution2d_nhwc_f32 failed!" + : "xnn_setup_convolution2d_nhwc_f32 failed!")); + + // Cache values to avoid setup for the next round + context.cached_input_ptr = padded_input_nhwc.data_ptr(); + context.cached_output_ptr = output.data_ptr(); + context.batch_size = padded_input_nhwc.size(Layout::Activation4D::batch); + context.input_channels = padded_input_nhwc.size(Layout::Activation4D::channels); + context.input_height = padded_input_nhwc.size(Layout::Activation4D::height); + context.input_width = padded_input_nhwc.size(Layout::Activation4D::width); } const xnn_status run_status = xnn_run_operator( @@ -265,12 +401,41 @@ c10::intrusive_ptr output_max); } +c10::intrusive_ptr + createConv2dTransposeClampPrePackOpContext( + Tensor weight, + c10::optional bias, + std::vector stride, + std::vector padding, + std::vector output_padding, + std::vector dilation, + int64_t groups, + c10::optional output_min, + c10::optional output_max) { + return xnnpack::XNNPackTransposeConv2dOpContext::create_context( + std::move(weight), + std::move(bias), + std::move(padding), + std::move(output_padding), + std::move(stride), + std::move(dilation), + groups, + output_min, + output_max); +} + Tensor conv2d_clamp_run( const Tensor& input, const c10::intrusive_ptr& op_context) { return op_context->run(input); } +Tensor conv2d_transpose_clamp_run( + const Tensor& input, + const c10::intrusive_ptr& op_context) { + return op_context->run(input); +} + } // namespace convolution2d } // namespace internal @@ -281,7 +446,8 @@ bool use_convolution2d( const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, - const int64_t groups) { + const int64_t groups, + const bool transposed) { return internal::convolution2d::available( weight, bias, @@ -289,6 +455,7 @@ bool use_convolution2d( stride, dilation, groups, + transposed, ContextConv2D::kMin, ContextConv2D::kMax) && internal::convolution2d::usable(input); @@ -307,9 +474,11 @@ Tensor convolution2d( weight, bias, padding, + {0, 0}, // output_padding stride, dilation, groups, + false, // transposed ContextConv2D::kMin, ContextConv2D::kMax); } diff --git a/aten/src/ATen/native/xnnpack/Convolution.h b/aten/src/ATen/native/xnnpack/Convolution.h index 8a7e2ae65ad1784..bc63a07d309b3fb 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.h +++ b/aten/src/ATen/native/xnnpack/Convolution.h @@ -23,17 +23,35 @@ c10::intrusive_ptr c10::optional output_min, c10::optional output_max); +c10::intrusive_ptr + createConv2dTransposeClampPrePackOpContext( + Tensor weight, + c10::optional bias, + std::vector stride, + std::vector padding, + std::vector output_padding, + std::vector dilation, + int64_t groups, + c10::optional output_min, + c10::optional output_max); + Tensor conv2d_clamp_run( const Tensor& input, const c10::intrusive_ptr& op_context); +Tensor conv2d_transpose_clamp_run( + const Tensor& input, + const c10::intrusive_ptr& op_context); + ContextConv2D create( const Tensor& weight, const c10::optional& bias, const IntArrayRef padding, + const IntArrayRef output_padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups, + const bool transposed, const float output_min, const float output_max); diff --git a/aten/src/ATen/native/xnnpack/Engine.h b/aten/src/ATen/native/xnnpack/Engine.h index 8c72038a48dccde..de3490cb3286e6b 100644 --- a/aten/src/ATen/native/xnnpack/Engine.h +++ b/aten/src/ATen/native/xnnpack/Engine.h @@ -17,7 +17,8 @@ bool use_convolution2d( const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, - const int64_t groups); + const int64_t groups, + const bool transposed); Tensor convolution2d( const Tensor& input, diff --git a/aten/src/ATen/native/xnnpack/OpContext.cpp b/aten/src/ATen/native/xnnpack/OpContext.cpp index 3ba8c318cbb91fc..fe78dcda1f99766 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.cpp +++ b/aten/src/ATen/native/xnnpack/OpContext.cpp @@ -48,13 +48,16 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight, weight, bias, padding, + {0, 0}, // output_padding stride, dilation, groups, + false, // transposed output_min ? output_min->to() : xnnpack::ContextConv2D::kMin, output_max ? output_max->to() : xnnpack::ContextConv2D::kMax); + auto conv2d_op_context = c10::make_intrusive( std::move(weight), @@ -66,6 +69,48 @@ XNNPackConv2dOpContext::create_context(at::Tensor&& weight, output_min, output_max, std::move(op_context)); + + return conv2d_op_context; +} + +c10::intrusive_ptr +XNNPackTransposeConv2dOpContext::create_context(at::Tensor&& weight, + c10::optional&& bias, + std::vector&& padding, + std::vector&& output_padding, + std::vector&& stride, + std::vector&& dilation, + int64_t groups, + const c10::optional output_min, + const c10::optional output_max) { + auto op_context = + xnnpack::internal::convolution2d::create( + weight, + bias, + padding, + output_padding, + stride, + dilation, + groups, + true, // transposed + output_min ? output_min->to() + : xnnpack::ContextConv2D::kMin, + output_max ? output_max->to() + : xnnpack::ContextConv2D::kMax); + + auto conv2d_op_context = + c10::make_intrusive( + std::move(weight), + std::move(bias), + std::move(padding), + std::move(output_padding), + std::move(stride), + std::move(dilation), + groups, + output_min, + output_max, + std::move(op_context)); + return conv2d_op_context; } @@ -73,6 +118,10 @@ Tensor XNNPackConv2dOpContext::run(const Tensor& input) { return xnnpack::internal::convolution2d::run(op_context_, input); } +Tensor XNNPackTransposeConv2dOpContext::run(const Tensor& input) { + return xnnpack::internal::convolution2d::run(op_context_, input); +} + } // namespace xnnpack } // namespace native } // namespace at diff --git a/aten/src/ATen/native/xnnpack/OpContext.h b/aten/src/ATen/native/xnnpack/OpContext.h index d46d5ee965bc434..e696ad3aa81dcdc 100644 --- a/aten/src/ATen/native/xnnpack/OpContext.h +++ b/aten/src/ATen/native/xnnpack/OpContext.h @@ -24,6 +24,18 @@ using SerializationTypeConv2dPrePack = std::tuple< int64_t, c10::optional, c10::optional>; +using SerializationTypeTransposeConv2dPrePack = std::tuple< + Tensor, + c10::optional, + std::vector, + std::vector, + std::vector, + std::vector, + int64_t, + c10::optional, + c10::optional>; + + class LinearOpContext : public torch::jit::CustomClassHolder { protected: @@ -94,6 +106,35 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder { virtual Tensor run(const Tensor& input) = 0; }; +class TransposeConv2dOpContext : public torch::jit::CustomClassHolder { + protected: + Tensor orig_weight_; + c10::optional orig_bias_; + std::vector stride_; + std::vector padding_; + std::vector output_padding_; + std::vector dilation_; + int64_t groups_; + c10::optional output_min_; + c10::optional output_max_; + + public: + SerializationTypeTransposeConv2dPrePack unpack() { + return std::make_tuple( + orig_weight_, + orig_bias_, + stride_, + padding_, + output_padding_, + dilation_, + groups_, + output_min_, + output_max_); + } + + virtual Tensor run(const Tensor& input) = 0; +}; + class XNNPackConv2dOpContext final : public Conv2dOpContext { private: ContextConv2D op_context_; @@ -120,7 +161,7 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext { output_max_ = max; } - Tensor run(const Tensor& input); + Tensor run(const Tensor& input) override; static c10::intrusive_ptr create_context( Tensor&& weight, @@ -132,6 +173,49 @@ class XNNPackConv2dOpContext final : public Conv2dOpContext { const c10::optional output_min, const c10::optional output_max); }; + +class XNNPackTransposeConv2dOpContext final : public TransposeConv2dOpContext { + private: + ContextConv2D op_context_; + + public: + XNNPackTransposeConv2dOpContext( + Tensor&& weight, + c10::optional&& bias, + std::vector&& padding, + std::vector&& output_padding, + std::vector&& stride, + std::vector&& dilation, + uint64_t groups, + c10::optional min, + c10::optional max, + ContextConv2D&& op_context) + : op_context_(std::move(op_context)) { + orig_weight_ = std::move(weight); + orig_bias_ = std::move(bias); + padding_ = std::move(padding); + output_padding_ = std::move(output_padding); + stride_ = std::move(stride); + dilation_ = std::move(dilation); + groups_ = groups; + output_min_ = min; + output_max_ = max; + } + + Tensor run(const Tensor& input) override; + + static c10::intrusive_ptr create_context( + Tensor&& weight, + c10::optional&& bias, + std::vector&& padding, + std::vector&& output_padding, + std::vector&& stride, + std::vector&& dilation, + int64_t groups, + const c10::optional output_min, + const c10::optional output_max); +}; + } // namespace xnnpack } // namespace native diff --git a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp index 55a67821533c637..e8442a64d0ad8ae 100644 --- a/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp +++ b/aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp @@ -13,6 +13,7 @@ namespace xnnpack { using internal::linear::createLinearClampPrePackOpContext; using internal::convolution2d::createConv2dClampPrePackOpContext; +using internal::convolution2d::createConv2dTransposeClampPrePackOpContext; TORCH_LIBRARY(xnnpack, m) { m.class_("LinearOpContext") @@ -48,20 +49,45 @@ TORCH_LIBRARY(xnnpack, m) { std::move(std::get<6>(state)), std::move(std::get<7>(state))); }); + + m.class_("TransposeConv2dOpContext") + .def_pickle( + [](const c10::intrusive_ptr& op_context) + -> SerializationTypeTransposeConv2dPrePack { // __getstate__ + return op_context->unpack(); + }, + [](SerializationTypeTransposeConv2dPrePack state) + -> c10::intrusive_ptr { // __setstate__ + return createConv2dTransposeClampPrePackOpContext( + std::move(std::get<0>(state)), + std::move(std::get<1>(state)), + std::move(std::get<2>(state)), + std::move(std::get<3>(state)), + std::move(std::get<4>(state)), + std::move(std::get<5>(state)), + std::move(std::get<6>(state)), + std::move(std::get<7>(state)), + std::move(std::get<8>(state))); + }); + } TORCH_LIBRARY(prepacked, m) { m.def("linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext"); m.def("linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y"); m.def("conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext"); + m.def("conv2d_transpose_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] output_padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.TransposeConv2dOpContext"); m.def("conv2d_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.Conv2dOpContext W_prepack) -> Tensor Y"); + m.def("conv2d_transpose_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.TransposeConv2dOpContext W_prepack) -> Tensor Y"); } TORCH_LIBRARY_IMPL(prepacked, CPU, m) { m.impl("linear_clamp_prepack", TORCH_FN(createLinearClampPrePackOpContext)); m.impl("linear_clamp_run", TORCH_FN(internal::linear::linear_clamp_run)); m.impl("conv2d_clamp_prepack", TORCH_FN(createConv2dClampPrePackOpContext)); + m.impl("conv2d_transpose_clamp_prepack", TORCH_FN(createConv2dTransposeClampPrePackOpContext)); m.impl("conv2d_clamp_run", TORCH_FN(internal::convolution2d::conv2d_clamp_run)); + m.impl("conv2d_transpose_clamp_run", TORCH_FN(internal::convolution2d::conv2d_transpose_clamp_run)); } } // namespace xnnpack diff --git a/aten/src/ATen/native/xnnpack/Shim.cpp b/aten/src/ATen/native/xnnpack/Shim.cpp index 326ebc6d200db89..e4b1f1cf667085d 100644 --- a/aten/src/ATen/native/xnnpack/Shim.cpp +++ b/aten/src/ATen/native/xnnpack/Shim.cpp @@ -35,7 +35,8 @@ bool use_convolution2d( const IntArrayRef, const IntArrayRef, const IntArrayRef, - const int64_t) { + const int64_t, + bool) { return false; } diff --git a/aten/src/ATen/native_parse.py b/aten/src/ATen/native_parse.py index 0a18cfd53e66aac..e8206e72141eb22 100644 --- a/aten/src/ATen/native_parse.py +++ b/aten/src/ATen/native_parse.py @@ -383,6 +383,11 @@ def parse_return_arguments(return_decl, inplace, func_decl): "Return Tensor of function \"{}\" flagged as inplace needs to be " \ "annotated as mutable".format(func_decl['func']) argument_dict['name'] = 'self' + elif t == "TensorList" and inplace: + assert annotation and annotation.endswith("!"), \ + "Return TensorList of function \"{}\" flagged as inplace needs to be " \ + "annotated as mutable".format(func_decl['func']) + argument_dict['name'] = 'self' else: argument_dict['name'] = 'result' if not multiple_args else 'result' + str(arg_idx) argument_dict['output'] = True diff --git a/aten/src/ATen/templates/TensorMethods.cpp b/aten/src/ATen/templates/TensorMethods.cpp index 1df7be9ff45c00f..064f5911cb1078a 100644 --- a/aten/src/ATen/templates/TensorMethods.cpp +++ b/aten/src/ATen/templates/TensorMethods.cpp @@ -12,15 +12,6 @@ #include #include -#ifdef USE_STATIC_DISPATCH -#include -#include -#include -#ifdef USE_VULKAN -#include -#endif -#endif - namespace at { Tensor Tensor::cpu() const { diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt index 0781eaad1efba2f..5b18eab3434cbfa 100644 --- a/aten/src/ATen/test/CMakeLists.txt +++ b/aten/src/ATen/test/CMakeLists.txt @@ -75,6 +75,7 @@ list(APPEND ATen_HIP_TEST_SRCS # ${CMAKE_CURRENT_SOURCE_DIR}/hip/hip_stream_test.cpp list(APPEND ATen_VULKAN_TEST_SRCS + ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_api_test.cpp ${CMAKE_CURRENT_SOURCE_DIR}/vulkan_test.cpp) list(APPEND ATen_MOBILE_TEST_SRCS diff --git a/aten/src/ATen/test/tensor_iterator_test.cpp b/aten/src/ATen/test/tensor_iterator_test.cpp index fcf7336c831f368..aecff33946f4f83 100644 --- a/aten/src/ATen/test/tensor_iterator_test.cpp +++ b/aten/src/ATen/test/tensor_iterator_test.cpp @@ -140,7 +140,7 @@ TEST(TensorIteratorTest, ComparisonLoopBinary_##name) { diff = in2.sub(in1); \ } \ auto expected = diff.clamp_min(0).to(kBool); \ - auto iter = TensorIterator::comparison_op(out, in1, in2, true); \ + auto iter = TensorIterator::comparison_op(out, in1, in2); \ at::native::cpu_serial_kernel(iter, [=](ctype a, ctype b) -> bool { return a < b; }); \ EXPECT_TRUE(out.equal(expected)); \ } diff --git a/aten/src/ATen/test/vulkan_api_test.cpp b/aten/src/ATen/test/vulkan_api_test.cpp new file mode 100644 index 000000000000000..28c1827485b7de9 --- /dev/null +++ b/aten/src/ATen/test/vulkan_api_test.cpp @@ -0,0 +1,16 @@ +#include + +#ifdef USE_VULKAN_API + +#include + +namespace { + +TEST(VulkanAPITest, Context) { + constexpr bool kDebug = true; + ASSERT_NO_THROW(at::native::vulkan::api::Context{kDebug}); +} + +} // namespace + +#endif /* USE_VULKAN_API */ diff --git a/benchmarks/static_runtime/deep_wide_pt.cc b/benchmarks/static_runtime/deep_wide_pt.cc new file mode 100644 index 000000000000000..6ce19abd8c84732 --- /dev/null +++ b/benchmarks/static_runtime/deep_wide_pt.cc @@ -0,0 +1,83 @@ +#include "deep_wide_pt.h" + +#include +#include + +namespace { +// No ReplaceNaN (this removes the constant in the model) +const std::string deep_wide_pt = R"JIT( +class DeepAndWide(Module): + __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ] + __buffers__ = [] + _mu : Tensor + _sigma : Tensor + _fc_w : Tensor + _fc_b : Tensor + training : bool + def forward(self: __torch__.DeepAndWide, + ad_emb_packed: Tensor, + user_emb: Tensor, + wide: Tensor) -> Tensor: + _0 = self._fc_b + _1 = self._fc_w + _2 = self._sigma + wide_offset = torch.add(wide, self._mu, alpha=1) + wide_normalized = torch.mul(wide_offset, _2) + wide_preproc = torch.clamp(wide_normalized, 0., 10.) + user_emb_t = torch.transpose(user_emb, 1, 2) + dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t) + dp = torch.flatten(dp_unflatten, 1, -1) + input = torch.cat([dp, wide_preproc], 1) + fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1) + return torch.sigmoid(fc1) +)JIT"; + +const std::string trivial_model_1 = R"JIT( + def forward(self, a, b, c): + s = torch.tensor([[3, 3], [3, 3]]) + return a + b * c + s +)JIT"; + +void import_libs( + std::shared_ptr cu, + const std::string& class_name, + const std::shared_ptr& src, + const std::vector& tensor_table) { + torch::jit::SourceImporter si( + cu, + &tensor_table, + [&](const std::string& /* unused */) -> std::shared_ptr { + return src; + }, + /*version=*/2); + si.loadType(c10::QualifiedName(class_name)); +} +} // namespace + +torch::jit::Module getDeepAndWideSciptModel(int num_features) { + auto cu = std::make_shared(); + std::vector constantTable; + import_libs( + cu, + "__torch__.DeepAndWide", + std::make_shared(deep_wide_pt), + constantTable); + c10::QualifiedName base("__torch__"); + auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide")); + + torch::jit::Module mod(cu, clstype); + + mod.register_parameter("_mu", torch::randn({1, num_features}), false); + mod.register_parameter("_sigma", torch::randn({1, num_features}), false); + mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false); + mod.register_parameter("_fc_b", torch::randn({1}), false); + + // mod.dump(true, true, true); + return mod; +} + +torch::jit::Module getTrivialScriptModel() { + torch::jit::Module module("m"); + module.define(trivial_model_1); + return module; +} diff --git a/benchmarks/static_runtime/deep_wide_pt.h b/benchmarks/static_runtime/deep_wide_pt.h index 3208c4e9f2eabc6..f4f394c7ef630fa 100644 --- a/benchmarks/static_runtime/deep_wide_pt.h +++ b/benchmarks/static_runtime/deep_wide_pt.h @@ -1,7 +1,5 @@ #pragma once -#include -#include #include struct DeepAndWide : torch::nn::Module { @@ -33,69 +31,6 @@ struct DeepAndWide : torch::nn::Module { torch::Tensor mu_, sigma_, fc_w_, fc_b_; }; -namespace { -// No ReplaceNaN (this removes the constant in the model) -const std::string deep_wide_pt = R"JIT( -class DeepAndWide(Module): - __parameters__ = ["_mu", "_sigma", "_fc_w", "_fc_b", ] - __buffers__ = [] - _mu : Tensor - _sigma : Tensor - _fc_w : Tensor - _fc_b : Tensor - training : bool - def forward(self: __torch__.DeepAndWide, - ad_emb_packed: Tensor, - user_emb: Tensor, - wide: Tensor) -> Tensor: - _0 = self._fc_b - _1 = self._fc_w - _2 = self._sigma - wide_offset = torch.add(wide, self._mu, alpha=1) - wide_normalized = torch.mul(wide_offset, _2) - wide_preproc = torch.clamp(wide_normalized, 0., 10.) - user_emb_t = torch.transpose(user_emb, 1, 2) - dp_unflatten = torch.bmm(ad_emb_packed, user_emb_t) - dp = torch.flatten(dp_unflatten, 1, -1) - input = torch.cat([dp, wide_preproc], 1) - fc1 = torch.addmm(_0, input, torch.t(_1), beta=1, alpha=1) - return torch.sigmoid(fc1) -)JIT"; +torch::jit::Module getDeepAndWideSciptModel(int num_features = 50); -void import_libs( - std::shared_ptr cu, - const std::string& class_name, - const std::shared_ptr& src, - const std::vector& tensor_table) { - torch::jit::SourceImporter si( - cu, - &tensor_table, - [&](const std::string& name) -> std::shared_ptr { - return src; - }, - /*version=*/2); - si.loadType(c10::QualifiedName(class_name)); -} -} // namespace - -inline torch::jit::Module getDeepAndWideSciptModel(int num_features = 50) { - auto cu = std::make_shared(); - std::vector constantTable; - import_libs( - cu, - "__torch__.DeepAndWide", - std::make_shared(deep_wide_pt), - constantTable); - c10::QualifiedName base("__torch__"); - auto clstype = cu->get_class(c10::QualifiedName(base, "DeepAndWide")); - - torch::jit::Module mod(cu, clstype); - - mod.register_parameter("_mu", torch::randn({1, num_features}), false); - mod.register_parameter("_sigma", torch::randn({1, num_features}), false); - mod.register_parameter("_fc_w", torch::randn({1, num_features + 1}), false); - mod.register_parameter("_fc_b", torch::randn({1}), false); - - // mod.dump(true, true, true); - return mod; -} +torch::jit::Module getTrivialScriptModel(); diff --git a/benchmarks/static_runtime/deep_wide_pt_bench.cc b/benchmarks/static_runtime/deep_wide_pt_bench.cc index c3334c289992463..ef960d28d7eb88a 100644 --- a/benchmarks/static_runtime/deep_wide_pt_bench.cc +++ b/benchmarks/static_runtime/deep_wide_pt_bench.cc @@ -1,6 +1,5 @@ #include #include - #include "deep_wide_pt.h" const int embedding_size = 32; diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc new file mode 100644 index 000000000000000..3ad0956ced737a6 --- /dev/null +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -0,0 +1,20 @@ +#include +#include +#include "deep_wide_pt.h" + +TEST(StaticRuntime, TrivialModel) { + torch::jit::Module mod = getTrivialScriptModel(); + auto a = torch::randn({2, 2}); + auto b = torch::randn({2, 2}); + auto c = torch::randn({2, 2}); + + // run jit graph executor + std::vector input_ivalues({a, b, c}); + at::Tensor output_1 = mod.forward(input_ivalues).toTensor(); + + // run static runtime + std::vector input_tensors({a, b, c}); + torch::jit::StaticRuntime runtime(mod); + at::Tensor output_2 = runtime.run(input_tensors)[0]; + EXPECT_TRUE(output_1.equal(output_2)); +} diff --git a/binaries/optimize_for_mobile.cc b/binaries/optimize_for_mobile.cc index 94293baeedd6458..4fb3044a031acaa 100644 --- a/binaries/optimize_for_mobile.cc +++ b/binaries/optimize_for_mobile.cc @@ -16,54 +16,71 @@ #include +#include "torch/script.h" #include "torch/csrc/jit/api/module.h" #include "torch/csrc/jit/passes/vulkan_rewrite.h" #include "torch/csrc/jit/passes/xnnpack_rewrite.h" #include "torch/csrc/jit/serialization/import.h" +#include "torch/csrc/jit/serialization/export.h" -C10_DEFINE_string(model, "", "The given torch script model to transform."); +C10_DEFINE_string(model, "", "The torch script model to optimize."); C10_DEFINE_string( output, "", "Name of the output model to be saved."); -C10_DEFINE_bool( - save_for_mobile, - false, - "Save the model with bytecode format compatible with lite inteprter."); -C10_DEFINE_bool(vulkan, false, "Vulkan optimize_for_mobile"); +C10_DEFINE_string(backend, "", "The backend to be optimized"); int main(int argc, char** argv) { c10::SetUsageMessage( - "Run speed benchmark for pytorch model.\n" - "Example usage:\n" + "\nRun optimization pass for pytorch model. Example usage:\n" "./optimize_for_mobile" " --model=" - " --output="); + " [--output=]" + " [--backend=]" + ); + if (!c10::ParseCommandLineFlags(&argc, &argv)) { std::cerr << "Failed to parse command line flags!" << std::endl; + std::cout << c10::UsageMessage() << std::endl; return 1; } - CAFFE_ENFORCE(FLAGS_model != "", "Valid input must be provided."); + CAFFE_ENFORCE(FLAGS_model != "", c10::UsageMessage()); std::string output_model_name = - FLAGS_model.substr(0, FLAGS_model.find(".")) + "_mobile_optimized.pt"; + FLAGS_model.substr(0, FLAGS_model.find(".")) + "_optimized.bc"; if (FLAGS_output != "") { output_model_name = FLAGS_output; } auto module = torch::jit::load(FLAGS_model); + auto ops = torch::jit::export_opnames(module); + std::cout << "\npt_operator_library(" << std::endl; + std::cout << "\tname = \"old_op_library\"," << std::endl; + std::cout << "\tops = [" << std::endl; + for (auto const& op: ops) { + std::cout << "\t\t\"" << op << "\"," << std::endl; + } + std::cout << "\t],\n)\n" << std::endl; - auto optimized_module = FLAGS_vulkan - ? torch::jit::vulkanOptimizeForMobile(module) - : torch::jit::optimizeForMobile(module); - - if (FLAGS_save_for_mobile) { - optimized_module._save_for_mobile(output_model_name); + torch::jit::Module optimized_module; + if (FLAGS_backend == "" || FLAGS_backend == "cpu") { + optimized_module = torch::jit::optimizeForMobile(module); + } else if (FLAGS_backend == "vulkan") { + optimized_module = torch::jit::vulkanOptimizeForMobile(module); } else { - optimized_module.save(output_model_name); + CAFFE_ENFORCE(false, "Unknown backend: " + FLAGS_backend); } - + auto new_ops = torch::jit::export_opnames(optimized_module); + std::cout << "\npt_operator_library(" << std::endl; + std::cout << "\tname = \"new_op_library\"," << std::endl; + std::cout << "\tops = [" << std::endl; + for (auto const& op: new_ops) { + std::cout << "\t\t\"" << op << "\"," << std::endl; + } + std::cout << "\t],\n)\n" << std::endl; + optimized_module._save_for_mobile(output_model_name); + std::cout << "The optimized model for lite interpreter was saved to " << output_model_name << std::endl; return 0; } diff --git a/c10/core/CPUCachingAllocator.cpp b/c10/core/CPUCachingAllocator.cpp index 2ef63cdb25e6181..232b8f2306e24b4 100644 --- a/c10/core/CPUCachingAllocator.cpp +++ b/c10/core/CPUCachingAllocator.cpp @@ -95,6 +95,7 @@ CPUCachingAllocator* GetThreadLocalCachingAllocator() { WithCPUCachingAllocatorGuard::WithCPUCachingAllocatorGuard( CPUCachingAllocator* allocator) { + caching_allocator_ptr = allocator; prev_caching_allocator_ptr_ = GetThreadLocalCachingAllocator(); } diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h index 82a3f7151151f71..9738f47ad175a8b 100644 --- a/c10/core/Scalar.h +++ b/c10/core/Scalar.h @@ -35,7 +35,7 @@ class C10_API Scalar { #undef DEFINE_IMPLICIT_CTOR // Value* is both implicitly convertible to SymbolicVariable and bool which - // causes ambiguosity error. Specialized constructor for bool resolves this + // causes ambiguity error. Specialized constructor for bool resolves this // problem. template < typename T, diff --git a/c10/macros/cmake_macros.h.in b/c10/macros/cmake_macros.h.in index ee192e12e1421c5..5e42506f20dcbc4 100644 --- a/c10/macros/cmake_macros.h.in +++ b/c10/macros/cmake_macros.h.in @@ -14,9 +14,4 @@ // to converging libtorch and caffe2 mobile builds and removing it eventually. #cmakedefine FEATURE_TORCH_MOBILE -// If defined it will use static dispatch for ATen operators. -// Should expose this macro for projects including ATen headers to inherient -// the same option. -#cmakedefine USE_STATIC_DISPATCH - #endif // C10_MACROS_CMAKE_MACROS_H_ diff --git a/torch/csrc/utils/hash.h b/c10/util/hash.h similarity index 71% rename from torch/csrc/utils/hash.h rename to c10/util/hash.h index 954a7b5b7d08140..4ddef3e564fc75a 100644 --- a/torch/csrc/utils/hash.h +++ b/c10/util/hash.h @@ -3,7 +3,7 @@ #include #include -namespace torch { +namespace c10 { // NOTE: hash_combine is based on implementation from Boost // @@ -36,35 +36,38 @@ inline size_t hash_combine(size_t seed, size_t value) { } //////////////////////////////////////////////////////////////////////////////// -// torch::hash implementation +// c10::hash implementation //////////////////////////////////////////////////////////////////////////////// namespace _hash_detail { -// Use template argument deduction to shorten calls to torch::hash -template +// Use template argument deduction to shorten calls to c10::hash +template size_t simple_get_hash(const T& o); -template -using type_if_not_enum = typename std::enable_if::value, V>::type; - -// Use SFINAE to dispatch to std::hash if possible, cast enum types to int automatically, -// and fall back to T::hash otherwise. -// NOTE: C++14 added support for hashing enum types to the standard, and some compilers -// implement it even when C++14 flags aren't specified. This is why we have to disable -// this overload if T is an enum type (and use the one below in this case). -template -auto dispatch_hash(const T& o) -> decltype(std::hash()(o), type_if_not_enum()) { +template +using type_if_not_enum = + typename std::enable_if::value, V>::type; + +// Use SFINAE to dispatch to std::hash if possible, cast enum types to int +// automatically, and fall back to T::hash otherwise. NOTE: C++14 added support +// for hashing enum types to the standard, and some compilers implement it even +// when C++14 flags aren't specified. This is why we have to disable this +// overload if T is an enum type (and use the one below in this case). +template +auto dispatch_hash(const T& o) + -> decltype(std::hash()(o), type_if_not_enum()) { return std::hash()(o); } -template -typename std::enable_if::value, size_t>::type dispatch_hash(const T& o) { +template +typename std::enable_if::value, size_t>::type dispatch_hash( + const T& o) { using R = typename std::underlying_type::type; return std::hash()(static_cast(o)); } -template +template auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) { return T::hash(o); } @@ -72,7 +75,7 @@ auto dispatch_hash(const T& o) -> decltype(T::hash(o), size_t()) { } // namespace _hash_detail // Hasher struct -template +template struct hash { size_t operator()(const T& o) const { return _hash_detail::dispatch_hash(o); @@ -80,17 +83,18 @@ struct hash { }; // Specialization for std::tuple -template +template struct hash> { - template + template struct tuple_hash { size_t operator()(const std::tuple& t) const { - return hash_combine(_hash_detail::simple_get_hash(std::get(t)), - tuple_hash()(t)); + return hash_combine( + _hash_detail::simple_get_hash(std::get(t)), + tuple_hash()(t)); } }; - template + template struct tuple_hash<0, Ts...> { size_t operator()(const std::tuple& t) const { return _hash_detail::simple_get_hash(std::get<0>(t)); @@ -98,16 +102,16 @@ struct hash> { }; size_t operator()(const std::tuple& t) const { - return tuple_hash()(t); + return tuple_hash()(t); } }; // Specialization for std::vector -template +template struct hash> { size_t operator()(const std::vector& v) const { size_t seed = 0; - for (const auto & elem : v) { + for (const auto& elem : v) { seed = hash_combine(seed, _hash_detail::simple_get_hash(elem)); } return seed; @@ -116,23 +120,23 @@ struct hash> { namespace _hash_detail { -template +template size_t simple_get_hash(const T& o) { - return torch::hash()(o); + return c10::hash()(o); } } // namespace _hash_detail // Use this function to actually hash multiple things in one line. -// Dispatches to torch::hash, so it can hash containers. +// Dispatches to c10::hash, so it can hash containers. // Example: // // static size_t hash(const MyStruct& s) { // return get_hash(s.member1, s.member2, s.member3); // } -template +template size_t get_hash(const Types&... args) { - return torch::hash()(std::tie(args...)); + return c10::hash()(std::tie(args...)); } -} // namespace torch +} // namespace c10 diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 541ad9e29b40e39..60525f0f4be6872 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -79,5 +79,4 @@ static_assert( {"USE_MKLDNN", "${CAFFE2_USE_MKLDNN}"}, \ {"USE_NVTX", "${CAFFE2_USE_NVTX}"}, \ {"USE_TRT", "${CAFFE2_USE_TRT}"}, \ - {"USE_STATIC_DISPATCH", "${USE_STATIC_DISPATCH}"}, \ } diff --git a/caffe2/python/operator_test/arg_ops_test.py b/caffe2/python/operator_test/arg_ops_test.py index 4e22dbc597d97c6..ce800636e6e6c82 100644 --- a/caffe2/python/operator_test/arg_ops_test.py +++ b/caffe2/python/operator_test/arg_ops_test.py @@ -16,7 +16,7 @@ class TestArgOps(serial.SerializedTestCase): @given( X=hu.tensor(dtype=np.float32), axis=st.integers(-1, 5), keepdims=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_argmax(self, X, axis, keepdims, gc, dc): if axis >= len(X.shape): axis %= len(X.shape) @@ -38,7 +38,7 @@ def argmax_ref(X): @given( X=hu.tensor(dtype=np.float32), axis=st.integers(-1, 5), keepdims=st.booleans(), **hu.gcs) - @settings(deadline=1000) + @settings(deadline=None) def test_argmin(self, X, axis, keepdims, gc, dc): if axis >= len(X.shape): axis %= len(X.shape) diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index b82529a9fd26c5c..9116dd2e317647a 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -167,8 +167,9 @@ if(INTERN_BUILD_ATEN_OPS) endif() if(SELECTED_OP_LIST) - if(NOT USE_STATIC_DISPATCH AND NOT OP_DEPENDENCY) - message(FATAL_ERROR "Must provide op dependency graph .yaml file for custom build with dynamic dispatch!") + if(NOT OP_DEPENDENCY) + message(INFO "Use default op dependency graph .yaml file for custom build with dynamic dispatch.") + set(OP_DEPENDENCY ${CMAKE_CURRENT_LIST_DIR}/../tools/code_analyzer/default_op_deps.yaml) endif() execute_process( COMMAND diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index c8750a272fe6fa7..6cd9429b43e9b8c 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -21,7 +21,6 @@ function(caffe2_print_configuration_summary) message(STATUS " TORCH_VERSION : ${TORCH_VERSION}") message(STATUS " CAFFE2_VERSION : ${CAFFE2_VERSION}") message(STATUS " BUILD_CAFFE2_MOBILE : ${BUILD_CAFFE2_MOBILE}") - message(STATUS " USE_STATIC_DISPATCH : ${USE_STATIC_DISPATCH}") message(STATUS " BUILD_BINARY : ${BUILD_BINARY}") message(STATUS " BUILD_CUSTOM_PROTOBUF : ${BUILD_CUSTOM_PROTOBUF}") if(${CAFFE2_LINK_LOCAL_PROTOBUF}) diff --git a/codecov.yml b/codecov.yml index 4142a3825659c19..79a3cd8057b194a 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,2 +1,7 @@ +coverage: + status: + project: + default: + threshold: 1% fixes: - "/opt/conda/lib/python3.8/site-packages/::project/" diff --git a/docs/source/linalg.rst b/docs/source/linalg.rst index 1152267f3609d11..834b6a60ac93fd2 100644 --- a/docs/source/linalg.rst +++ b/docs/source/linalg.rst @@ -13,3 +13,4 @@ Functions --------- .. autofunction:: det +.. autofunction:: norm diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index ad859bdf8e4b68a..c6a7f29036ee4a9 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -518,6 +518,8 @@ view of a storage and defines numeric operations on it. .. automethod:: sinh_ .. automethod:: asinh .. automethod:: asinh_ + .. automethod:: arcsinh + .. automethod:: arcsinh_ .. automethod:: size .. automethod:: slogdet .. automethod:: solve @@ -539,6 +541,8 @@ view of a storage and defines numeric operations on it. .. automethod:: stride .. automethod:: sub .. automethod:: sub_ + .. automethod:: subtract + .. automethod:: subtract_ .. automethod:: sum .. automethod:: sum_to_size .. automethod:: svd @@ -554,6 +558,8 @@ view of a storage and defines numeric operations on it. .. automethod:: tanh_ .. automethod:: atanh .. automethod:: atanh_ + .. automethod:: arctanh + .. automethod:: arctanh_ .. automethod:: tolist .. automethod:: topk .. automethod:: to_sparse diff --git a/docs/source/torch.rst b/docs/source/torch.rst index e63563c178bde8d..7ab9b5a61c646e2 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -261,9 +261,11 @@ Pointwise Ops asin arcsin asinh + arcsinh atan arctan atanh + arctanh atan2 bitwise_not bitwise_and @@ -323,6 +325,8 @@ Pointwise Ops sinh sqrt square + sub + subtract tan tanh true_divide diff --git a/mode/aibench_caffe2_android b/mode/aibench_caffe2_android new file mode 100644 index 000000000000000..814881179c40393 --- /dev/null +++ b/mode/aibench_caffe2_android @@ -0,0 +1,3 @@ +--config +caffe2.strip_glog=0 +@fbsource//fbandroid/mode/ndk_libcxx diff --git a/mode/aibench_pytorch_android b/mode/aibench_pytorch_android new file mode 100644 index 000000000000000..2572d24d3032e5a --- /dev/null +++ b/mode/aibench_pytorch_android @@ -0,0 +1,5 @@ +--config +user.ndk_cxxflags='-g1' +--config +pt.disable_per_op_profiling=0 +@fbsource//fbandroid/mode/ndk_libcxx diff --git a/mypy.ini b/mypy.ini index ce6bd2ab3c57a02..d2765089197c6b0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -60,9 +60,6 @@ ignore_errors = True [mypy-torch.quantization.default_mappings] ignore_errors = True -[mypy-torch.quantization.fuse_modules] -ignore_errors = True - [mypy-torch.quantization.observer] ignore_errors = True diff --git a/scripts/build_android.sh b/scripts/build_android.sh index dc0101158e5b120..21a0602990b6198 100755 --- a/scripts/build_android.sh +++ b/scripts/build_android.sh @@ -61,7 +61,6 @@ CMAKE_ARGS=() if [ -z "${BUILD_CAFFE2_MOBILE:-}" ]; then # Build PyTorch mobile - CMAKE_ARGS+=("-DUSE_STATIC_DISPATCH=ON") CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$($PYTHON -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") CMAKE_ARGS+=("-DPYTHON_EXECUTABLE=$($PYTHON -c 'import sys; print(sys.executable)')") CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") diff --git a/scripts/build_ios.sh b/scripts/build_ios.sh index aebddac5cce8e18..75ee1ed2001a150 100755 --- a/scripts/build_ios.sh +++ b/scripts/build_ios.sh @@ -13,7 +13,6 @@ CMAKE_ARGS=() if [ -z "${BUILD_CAFFE2_MOBILE:-}" ]; then # Build PyTorch mobile - CMAKE_ARGS+=("-DUSE_STATIC_DISPATCH=ON") CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") CMAKE_ARGS+=("-DPYTHON_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") @@ -62,7 +61,7 @@ fi # IOS_PLATFORM controls type of iOS platform (see ios-cmake) if [ -n "${IOS_PLATFORM:-}" ]; then - CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") + CMAKE_ARGS+=("-DIOS_PLATFORM=${IOS_PLATFORM}") if [ "${IOS_PLATFORM}" == "WATCHOS" ]; then # enable bitcode by default for watchos CMAKE_ARGS+=("-DCMAKE_C_FLAGS=-fembed-bitcode") diff --git a/scripts/build_mobile.sh b/scripts/build_mobile.sh index c468a62baa9e1b0..c413f86e4d67b51 100755 --- a/scripts/build_mobile.sh +++ b/scripts/build_mobile.sh @@ -15,7 +15,6 @@ echo "Bash: $(/bin/bash --version | head -1)" echo "Caffe2 path: $CAFFE2_ROOT" CMAKE_ARGS=() -CMAKE_ARGS+=("-DUSE_STATIC_DISPATCH=ON") CMAKE_ARGS+=("-DCMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')") CMAKE_ARGS+=("-DPYTHON_EXECUTABLE=$(python -c 'import sys; print(sys.executable)')") CMAKE_ARGS+=("-DBUILD_CUSTOM_PROTOBUF=OFF") diff --git a/setup.py b/setup.py index 4027255ae9e9cee..508dcdd94e9f3fa 100644 --- a/setup.py +++ b/setup.py @@ -171,6 +171,17 @@ print("32-bit Windows Python runtime is not supported. Please switch to 64-bit Python.") sys.exit(-1) +import platform +python_min_version = (3, 6, 1) +python_min_version_str = '.'.join((str(num) for num in python_min_version)) +python_max_version = (3, 9, 0) +python_max_version_str = '.'.join((str(num) for num in python_max_version)) +if sys.version_info < python_min_version or sys.version_info >= python_max_version: + print("You are using Python {}. Python >={},<{} is required.".format(platform.python_version(), + python_min_version_str, + python_max_version_str)) + sys.exit(-1) + from setuptools import setup, Extension, distutils, find_packages from collections import defaultdict from distutils import core @@ -883,7 +894,7 @@ def print_box(msg): download_url='https://github.com/pytorch/pytorch/tags', author='PyTorch Team', author_email='packages@pytorch.org', - python_requires='>=3.6.1', + python_requires='>={},<{}'.format(python_min_version_str, python_max_version_str), # PyPI package information. classifiers=[ 'Development Status :: 5 - Production/Stable', @@ -891,18 +902,15 @@ def print_box(msg): 'Intended Audience :: Education', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: BSD License', - 'Programming Language :: C++', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', - ], + 'Programming Language :: C++', + 'Programming Language :: Python :: 3', + ] + ['Programming Language :: Python :: 3.{}' for i in range(python_min_version[1], python_max_version[1])], license='BSD-3', keywords='pytorch machine learning', ) diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py index cef30ef14fdbb48..a461d6247ca4c08 100644 --- a/test/backward_compatibility/check_backward_compatibility.py +++ b/test/backward_compatibility/check_backward_compatibility.py @@ -18,6 +18,13 @@ # believe you can land your diff before then. # # Allowlist entries can be removed after the date listed on them passes. +# +# Allowlist item format: +# [ +# 0: function name regex +# 1: date until which the allowlist entry is valid +# 2: (optional) function argument regex +# ] allow_list = [ ("c10_experimental", datetime.date(2222, 1, 1)), # We export some functions and classes for test_jit.py directly from libtorch.so, @@ -60,6 +67,7 @@ ("aten::linalg_outer", datetime.date(2020, 8, 30)), ("aten::linalg_outer.out", datetime.date(2020, 8, 30)), ("aten::_compute_linear_combination", datetime.date(2020, 9, 1)), + ("__getstate__", datetime.date(2020, 9, 1), "Conv[23]dPackedParams"), ] @@ -69,6 +77,10 @@ def allow_listed(schema, allow_list): continue regexp = re.compile(item[0]) if regexp.search(schema.name): + if len(item) > 2: + # if arguments regex is present, use it + regexp_args = re.compile(item[2]) + return bool(regexp_args.search(str(schema))) return True return False diff --git a/test/cpp/api/fft.cpp b/test/cpp/api/fft.cpp index 7deae25b8127796..3a62b9eea8b5d4d 100644 --- a/test/cpp/api/fft.cpp +++ b/test/cpp/api/fft.cpp @@ -1,9 +1,18 @@ #include #include +#include + + +// Tests that the fft function can be called as usual +TEST(FFTTest, unclobbered_fft) { + auto t = torch::randn({64, 2}, torch::dtype(torch::kDouble)); + torch::fft(t, 1); +} + +// Clobbers torch::fft the function with torch::fft the namespace #include -#include // NOTE: Visual Studio and ROCm builds don't understand complex literals // as of August 2020 diff --git a/test/cpp/api/transformer.cpp b/test/cpp/api/transformer.cpp index 2383b76d4dca7d5..59ab0d0f5e1f483 100644 --- a/test/cpp/api/transformer.cpp +++ b/test/cpp/api/transformer.cpp @@ -163,14 +163,17 @@ void transformer_decoder_layer_test_helper(bool is_cuda){ torch::TensorOptions tensor_options = torch::TensorOptions() .dtype(torch::kFloat32).device(device); - TransformerDecoderLayer model = - get_a_test_layer(tensor_options); + TransformerDecoderLayer model = get_a_test_layer< + TransformerDecoderLayer, + TransformerDecoderLayerOptions>(tensor_options); // deterministic input - at::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); - at::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); - at::Tensor result = model(decoder_input, memory_input).detach(); - at::Tensor ref_output = torch::tensor( + torch::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, + tensor_options); + torch::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, + tensor_options); + torch::Tensor result = model(decoder_input, memory_input).detach(); + torch::Tensor ref_output = torch::tensor( {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options); ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); @@ -235,9 +238,9 @@ void transformer_decoder_layer_test_helper(bool is_cuda){ /*equal_nan=*/true)); // key_padding_mask - at::Tensor t_mask = {}; - at::Tensor m_mask = {}; - at::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; + torch::Tensor t_mask = {}; + torch::Tensor m_mask = {}; + torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask).detach(); ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, @@ -269,7 +272,7 @@ void transformer_decoder_layer_test_helper(bool is_cuda){ /*equal_nan=*/true)); // memory_key_padding_mask - at::Tensor t_key_padding_mask = {}; + torch::Tensor t_key_padding_mask = {}; key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1; result = model(decoder_input, memory_input, t_mask, m_mask, t_key_padding_mask, key_padding_mask).detach(); @@ -317,17 +320,18 @@ void transformer_decoder_layer_test_helper_gelu(bool is_cuda) { torch::TensorOptions tensor_options = torch::TensorOptions() .dtype(torch::kFloat32).device(device); - TransformerDecoderLayer model = - get_a_test_layer(tensor_options); + TransformerDecoderLayer model = get_a_test_layer< + TransformerDecoderLayer, + TransformerDecoderLayerOptions>(tensor_options); model.get()->options.activation(torch::kGELU); // deterministic input - at::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, + torch::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); - at::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, + torch::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); - at::Tensor result = model(decoder_input, memory_input).detach(); - at::Tensor ref_output = torch::tensor( + torch::Tensor result = model(decoder_input, memory_input).detach(); + torch::Tensor ref_output = torch::tensor( {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options); ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); @@ -597,3 +601,490 @@ TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) { " (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n" ")"); } + +void transformer_decoder_test_helper(bool is_cuda) { + // this is a deterministic test for TransformerDecoder + torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU; + torch::TensorOptions tensor_options = + torch::TensorOptions().dtype(torch::kFloat32).device(device); + + TransformerDecoderLayer decoder_layer = get_a_test_layer< + TransformerDecoderLayer, + TransformerDecoderLayerOptions>(tensor_options); + + TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1)); + if (is_cuda) { + model->to(torch::kCUDA); + } + + + torch::Tensor decoder_input = torch::tensor({{{20, 30, 40, 50}}}, + tensor_options); + torch::Tensor memory_input = torch::tensor({{{60, 70, 80, 90}}}, + tensor_options); + torch::Tensor result = model(decoder_input, memory_input).detach(); + torch::Tensor ref_output = torch::tensor( + {{{2.314351, 0.094805, -0.671322, 0.101977}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + +// deterministic input + decoder_input = torch::tensor({{{9, 10, 11, 12}}, + {{11, 12, 13, 14}}}, tensor_options); + memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.422245, 0.051716, -0.606338, -0.024756}}, + {{2.422245, 0.051716, -0.606338, -0.024756}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = torch::tensor({{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}}, tensor_options); + memory_input = torch::tensor({{{9, 10, 11, 12}}, + {{11, 12, 13, 14}}}, tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.343536, 0.085561, -0.654954, 0.074991}}, + {{2.343536, 0.085561, -0.654954, 0.074991}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + + // deterministic input + decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, + {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, + {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, + {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, + {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, + {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, + {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, + {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, + {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // key_padding_mask + torch::Tensor t_mask = {}; + torch::Tensor m_mask = {}; + torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1; + result = model(decoder_input, memory_input, t_mask, m_mask, + key_padding_mask).detach(); + ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // key_padding_mask + key_padding_mask[0][2] = 1; + key_padding_mask[1][1] = 1; + key_padding_mask[1][2] = 1; + result = model(decoder_input, memory_input, t_mask, m_mask, + key_padding_mask).detach(); + ref_output = torch::tensor({{{2.430025, 0.027643, -0.601164, -0.073476}, + {2.4323, 0.029375, -0.599553, -0.071881}}, + {{2.428523, 0.026838, -0.602226, -0.07391}, + {2.432634, 0.029842, -0.599318, -0.071253}}, + {{2.432278, 0.028152, -0.599555, -0.074139}, + {2.432659, 0.029244, -0.599294, -0.072382}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // memory_key_padding_mask + torch::Tensor t_key_padding_mask = {}; + key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1; + result = model(decoder_input, memory_input, t_mask, m_mask, + t_key_padding_mask, key_padding_mask).detach(); + ref_output = torch::tensor({{{2.430065, 0.027862, -0.601136, -0.073096}, + {2.431935, 0.028907, -0.599809, -0.072488}}, + {{2.428457, 0.027053, -0.602275, -0.073462}, + {2.431970, 0.029387, -0.599789, -0.071621}}, + {{2.431934, 0.028196, -0.599802, -0.073809}, + {2.432306, 0.028858, -0.599542, -0.072846}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // memory_key_padding_mask + key_padding_mask[0][4] = 1; + key_padding_mask[1][3] = 1; + key_padding_mask[1][4] = 1; + result = model(decoder_input, memory_input, t_mask, m_mask, + t_key_padding_mask, key_padding_mask).detach(); + ref_output = torch::tensor({{{2.429757, 0.027358, -0.601351, -0.073816}, + {2.432692, 0.028583, -0.599263, -0.073634}}, + {{2.428247, 0.02662, -0.602419, -0.074123}, + {2.432657, 0.029055, -0.599293, -0.072732}}, + {{2.431515, 0.027687, -0.600096, -0.074459}, + {2.433075, 0.028543, -0.598987, -0.073985}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // multiple layers no norm + model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2)); + if (is_cuda) { + model->to(torch::kCUDA); + } + + decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor( + {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // multiple layers no norm + model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)); + if (is_cuda) { + model->to(torch::kCUDA); + } + // deterministic input + decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, + {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, + {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, + {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, + {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, + {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, + {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, + {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, + {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.42794, 0.026164, -0.60263, -0.0747591}, + {2.43113, 0.0279516, -0.600376, -0.0736896}}, + {{2.42794, 0.026164, -0.60263, -0.0747591}, + {2.43113, 0.0279516, -0.600376, -0.0736896}}, + {{2.42794, 0.026164, -0.60263, -0.0747591}, + {2.43113, 0.0279516, -0.600376, -0.0736896}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + + // multiple layers with norm + LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()})); + model = TransformerDecoder( + TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm))); + if (is_cuda) { + model->to(torch::kCUDA); + } + + decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options); + memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor( + {{{1.66166, -0.326986, -1.01466, -0.320017}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // multiple layers with norm + model = TransformerDecoder( + TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm))); + if (is_cuda) { + model->to(torch::kCUDA); + } + // deterministic input + decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, + {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, + {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, + {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, + {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, + {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, + {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, + {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, + {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{1.69559, -0.357291, -0.894741, -0.443553}, + {1.69571, -0.357363, -0.894154, -0.444196}}, + {{1.69559, -0.357291, -0.894741, -0.443553}, + {1.69571, -0.357363, -0.894154, -0.444196}}, + {{1.69559, -0.357291, -0.894741, -0.443553}, + {1.69571, -0.357363, -0.894154, -0.444196}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + //gelu activation test cases + decoder_layer.get()->options.activation(torch::kGELU); + model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1)); + if (is_cuda) { + model->to(torch::kCUDA); + } + + // deterministic input + decoder_input = torch::tensor({{{20, 30, 40, 50}}}, + tensor_options); + memory_input = torch::tensor({{{60, 70, 80, 90}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor( + {{{2.306435, 0.095946, -0.675796, 0.10687}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = torch::tensor({{{9, 10, 11, 12}}, + {{11, 12, 13, 14}}}, + tensor_options); + memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.415448, 0.054389, -0.610932, -0.0156613}}, + {{2.415448, 0.054389, -0.610932, -0.0156613}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = torch::tensor({{{1, 2, 3, 4}}, + {{5, 6, 7, 8}}}, + tensor_options); + memory_input = torch::tensor({{{9, 10, 11, 12}}, + {{11, 12, 13, 14}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.338531, 0.087709, -0.65776, 0.080646}}, + {{2.338531, 0.087709, -0.65776, 0.080646}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // deterministic input + decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, + {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, + {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, + {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, + {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, + {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, + {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, + {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, + {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor( + {{{2.42049104, 0.03443088, -0.60793706, -0.05436271}, + {2.42210631, 0.03546578, -0.60679895, -0.05357488}}, + {{2.41907674, 0.0336104, -0.60892977, -0.05490462}, + {2.42216881, 0.03586554, -0.6067524, -0.05289126}}, + {{2.42205716, 0.03488046, -0.60683681, -0.05460596}, + {2.42240309, 0.0354595, -0.60659063, -0.05378816}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // Multiple layers no norm + model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6)); + if (is_cuda) { + model->to(torch::kCUDA); + } + decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, + {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, + {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, + {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, + {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, + {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, + {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, + {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, + {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{2.41859, 0.0328114, -0.609269, -0.0560386}, + {2.42138, 0.034598, -0.607316, -0.0546574}}, + {{2.41859, 0.0328114, -0.609269, -0.0560386}, + {2.42138, 0.034598, -0.607316, -0.0546574}}, + {{2.41859, 0.0328114, -0.609269, -0.0560386}, + {2.42138, 0.034598, -0.607316, -0.0546574}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + + // Multiple layers with norm + norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()})); + model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6) + .norm(AnyModule(norm))); + if (is_cuda) { + model->to(torch::kCUDA); + } + + decoder_input = torch::tensor({{{0.4517, 0.6793, 0.5313, 0.0034}, + {0.2678, 0.3677, 0.4459, 0.7166}}, + {{0.8100, 0.3716, 0.4096, 0.1976}, + {0.6958, 0.8844, 0.6081, 0.8315}}, + {{0.0494, 0.9343, 0.5955, 0.3830}, + {0.5404, 0.3464, 0.9378, 0.6200}}}, + tensor_options); + memory_input = torch::tensor({{{0.7462, 0.6653, 0.5679, 0.4891}, + {0.5387, 0.1655, 0.3565, 0.0471}}, + {{0.8335, 0.2799, 0.5031, 0.2947}, + {0.1402, 0.0318, 0.7636, 0.1346}}, + {{0.6333, 0.9344, 0.1376, 0.9938}, + {0.8924, 0.2872, 0.6692, 0.2944}}, + {{0.9897, 0.6915, 0.3154, 0.1733}, + {0.8645, 0.3513, 0.3064, 0.0767}}, + {{0.8117, 0.2366, 0.4838, 0.7881}, + {0.3718, 0.4945, 0.9511, 0.0864}}}, + tensor_options); + result = model(decoder_input, memory_input).detach(); + ref_output = torch::tensor({{{1.69298, -0.355163, -0.906375, -0.431439}, + {1.69305, -0.355195, -0.906062, -0.431791}}, + {{1.69298, -0.355163, -0.906375, -0.431439}, + {1.69305, -0.355195, -0.906062, -0.431791}}, + {{1.69298, -0.355163, -0.906375, -0.431439}, + {1.69305, -0.355195, -0.906062, -0.431791}}}, + tensor_options); + ASSERT_EQ(result.sizes().size(),ref_output.sizes().size()); + ASSERT_TRUE(torch::allclose(result, ref_output, 1e-7, 1e-5, + /*equal_nan=*/true)); + +} + +TEST_F(TransformerTest, TransformerDecoder) { + transformer_decoder_test_helper(false); +} + +TEST_F(TransformerTest, TransformerDecoder_CUDA) { + transformer_decoder_test_helper(true); +} + + +TEST_F(TransformerTest, PrettyPrintTransformerDecoder) { + LayerNorm norm = LayerNorm(LayerNormOptions({4})); + TransformerDecoderOptions options( + TransformerDecoderOptions( + TransformerDecoderLayerOptions(4, 2),2).norm(AnyModule(norm))); + ASSERT_EQ( + c10::str(TransformerDecoder(options)), + "torch::nn::TransformerDecoderImpl(\n" + " (layers): torch::nn::ModuleList(\n" + " (0): torch::nn::TransformerDecoderLayerImpl(\n" + " (self_attn): torch::nn::MultiheadAttention(\n" + " (out_proj): torch::nn::Linear(in_features=4, out_features=4," + " bias=true)\n" + " )\n" + " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (norm1): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + " (multihead_attn): torch::nn::MultiheadAttention(\n" + " (out_proj): torch::nn::Linear(in_features=4, out_features=4," + " bias=true)\n" + " )\n" + " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (norm2): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + " (linear1): torch::nn::Linear(in_features=4, out_features=2048," + " bias=true)\n" + " (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (linear2): torch::nn::Linear(in_features=2048, out_features=4," + " bias=true)\n" + " (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (norm3): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + " )\n" + " (1): torch::nn::TransformerDecoderLayerImpl(\n" + " (self_attn): torch::nn::MultiheadAttention(\n" + " (out_proj): torch::nn::Linear(in_features=4, out_features=4," + " bias=true)\n" + " )\n" + " (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (norm1): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + " (multihead_attn): torch::nn::MultiheadAttention(\n" + " (out_proj): torch::nn::Linear(in_features=4, out_features=4," + " bias=true)\n" + " )\n" + " (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (norm2): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + " (linear1): torch::nn::Linear(in_features=4, out_features=2048," + " bias=true)\n" + " (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (linear2): torch::nn::Linear(in_features=2048, out_features=4," + " bias=true)\n" + " (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n" + " (norm3): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + " )\n" + " )\n" + " (norm): torch::nn::LayerNorm([4], eps=1e-05," + " elementwise_affine=true)\n" + ")"); +} diff --git a/test/cpp/jit/test_custom_operators.cpp b/test/cpp/jit/test_custom_operators.cpp index 064f34c929ec3b5..529b36385bd49c2 100644 --- a/test/cpp/jit/test_custom_operators.cpp +++ b/test/cpp/jit/test_custom_operators.cpp @@ -5,6 +5,7 @@ #include "torch/csrc/jit/ir/irparser.h" #include "torch/csrc/jit/passes/dead_code_elimination.h" #include "torch/csrc/jit/runtime/custom_operator.h" +#include "torch/csrc/jit/runtime/register_ops_utils.h" #include "torch/jit.h" namespace torch { @@ -191,5 +192,68 @@ void testIValueKWargs() { ASSERT_EQ(result.toInt(), 19); } +void testTemplatedOperatorCreator() { + constexpr char op_list[] = "foofoo::bar.template;foo::another"; +#define TORCH_SELECTIVE_NAME_IN_SCHEMA(l, n) \ + torch::detail::SelectiveStr(n) + + { + // Try to register an op name that does not exist in op_list. + // Expected: the op name is not registered. + torch::jit::RegisterOperators reg({OperatorGenerator( + TORCH_SELECTIVE_NAME_IN_SCHEMA( + op_list, "foofoo::not_exist(float a, Tensor b) -> Tensor"), + [](Stack* stack) { + double a; + at::Tensor b; + pop(stack, a, b); + push(stack, a + b); + }, + aliasAnalysisFromSchema())}); + + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::not_exist")); + ASSERT_EQ(ops.size(), 0); + } + + { + // The operator should be successfully registered since its name is in the + // whitelist. + torch::jit::RegisterOperators reg({OperatorGenerator( + TORCH_SELECTIVE_NAME_IN_SCHEMA( + op_list, "foofoo::bar.template(float a, Tensor b) -> Tensor"), + [](Stack* stack) { + double a; + at::Tensor b; + pop(stack, a, b); + push(stack, a + b); + }, + aliasAnalysisFromSchema())}); + + auto& ops = getAllOperatorsFor(Symbol::fromQualString("foofoo::bar")); + ASSERT_EQ(ops.size(), 1); + + auto& op = ops.front(); + ASSERT_EQ(op->schema().name(), "foofoo::bar"); + + ASSERT_EQ(op->schema().arguments().size(), 2); + ASSERT_EQ(op->schema().arguments()[0].name(), "a"); + ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType); + ASSERT_EQ(op->schema().arguments()[1].name(), "b"); + ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType); + + ASSERT_EQ(op->schema().returns().size(), 1); + ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType); + + Stack stack; + push(stack, 2.0f, at::ones(5)); + op->getOperation()(&stack); + at::Tensor output; + pop(stack, output); + + ASSERT_TRUE(output.allclose(at::full(5, 3.0f))); + } +} + } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_irparser.cpp b/test/cpp/jit/test_irparser.cpp index e5d73bae212dfbe..e4e948beca459ec 100644 --- a/test/cpp/jit/test_irparser.cpp +++ b/test/cpp/jit/test_irparser.cpp @@ -338,6 +338,26 @@ graph(%a : Float(*, *, device=cpu), return (%a) )IR"); } + { + auto graph = std::make_shared(); + parseIR( + R"IR( +graph(): + %d : int[] = prim::Constant[value=[1,2,3]]() + return (%d) +)IR", + &*graph); + Node* n = graph->outputs()[0]->node(); + AT_ASSERT(n->kind() == prim::Constant); + AT_ASSERT(n->kindOf(attr::value) == AttributeKind::ival); + const auto& genericList = n->ival(attr::value).toList(); + std::vector int_vals; + for (const IValue& ival : genericList) { + int_vals.push_back(ival.toInt()); + } + AT_ASSERT(int_vals.size() == 3); + AT_ASSERT(int_vals[0] == 1 && int_vals[1] == 2 && int_vals[2] == 3); + } } } // namespace jit } // namespace torch diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 8f6fcace97c4985..d2aa322c2bbd28d 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -8,6 +8,7 @@ #include #include +#include #include "torch/csrc/autograd/generated/variable_factories.h" #include "torch/csrc/autograd/variable.h" #include "torch/csrc/jit/codegen/fuser/interface.h" @@ -45,8 +46,10 @@ #include "torch/csrc/autograd/engine.h" #include "torch/csrc/autograd/variable.h" +#include #include #include + #include "torch/csrc/jit/api/module.h" #include "torch/csrc/jit/frontend/ir_emitter.h" #include "torch/csrc/jit/runtime/profiling_record.h" @@ -1196,6 +1199,84 @@ void testThreadLocalDebugInfo() { } } +void testFallbackGraphs() { + static const auto nestGraphIntoFallbackGraph = + [](const std::shared_ptr& graph) { + ProfilingRecord::removeProfileCounter(graph->block()); + auto fallback = + createFallbackGraph(graph->block(), graph->inputs(), graph.get()); + graph->prependNode(fallback); + for (size_t i = 0; i < graph->outputs().size(); i++) { + graph->outputs()[i]->replaceAllUsesWith(fallback->output(i)); + fallback->output(i)->copyMetadata(graph->outputs()[i]); + } + for (auto it = graph->block()->nodes().rbegin(); + it != fallback->iterator(); + it++) { + it.destroyCurrent(); + } + }; + + auto x = at::randn({1}, at::kCPU); + auto y = at::randn({1}, at::kCPU); + auto stack = createStack({x.clone(), y.clone()}); + + auto graph_string = R"IR( + graph(%0 : Float(1), + %1 : Float(1)): + %2 : Tensor = aten::mul(%0, %1) + %3 : Tensor = aten::mul(%2, %0) + return (%3))IR"; + auto graph = std::make_shared(); + torch::jit::parseIR(graph_string, graph.get()); + + { + Code code(graph, ""); + InterpreterState interpreter{code}; + interpreter.run(stack); + } + at::Tensor et; + pop(stack, et); + float ef = et.item(); + { + EnableProfilingGuard epg; + GraphFunction f("fallbackGraphs", graph, nullptr); + for (size_t i = 0; i < getNumProfiledRuns() + 1; i++) { + stack.emplace_back(x.clone()); + stack.emplace_back(y.clone()); + if (i == getNumProfiledRuns()) { + // we will be modifying a profiled graph + // before ProfilingGraphExecutor + // will optimize it in the next iteration + auto opt_graph = lastExecutedOptimizedGraph(); + // this is safe to do since we are done profiling + ProfilingRecord::removeProfileCounter(opt_graph->block()); + replaceBlockWithFallbackGraph(opt_graph->block()); + GRAPH_DUMP("replaceBlockWithFallbackGraph:", opt_graph); + auto it = opt_graph->block()->nodes().begin(); + ASSERT_EQ(it->kind(), prim::FallbackGraph); + auto fallback = *it++; + ASSERT_EQ(it, opt_graph->block()->nodes().end()); + ASSERT_TRUE(fallback->hasAttribute(attr::Subgraph)); + testing::FileCheck() + .check("Tensor = aten::mul") + ->check("Tensor = aten::mul") + ->run(*fallback->g(attr::Subgraph)); + } + f.run(stack); + at::Tensor at; + pop(stack, at); + float af = at.item(); + ASSERT_EQ(af, ef); + } + + auto opt_graph = lastExecutedOptimizedGraph(); + testing::FileCheck() + .check("(Tensor) = prim::CallFunction") + ->run(*opt_graph); + } +} + void testAutogradProfiler() { constexpr int batch_size = 4; constexpr int input_size = 256; diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 66cd0cebb851271..c1a6b22668e0a47 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -19,6 +19,7 @@ namespace jit { _(CreateAutodiffSubgraphs) \ _(CustomOperators) \ _(CustomOperatorAliasing) \ + _(TemplatedOperatorCreator) \ _(IValueKWargs) \ _(CustomFusion) \ _(SchemaMatching) \ @@ -49,6 +50,7 @@ namespace jit { _(ClassParser) \ _(UnifyTypes) \ _(Profiler) \ + _(FallbackGraphs) \ _(InsertAndEliminateRedundantGuards) \ _(LoopPeeler) \ _(InsertBailOuts) \ diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 4759b07b40a4df4..ad555d7a6219f83 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -327,12 +327,11 @@ void testKernelSumOneAxis() { // Test lowering of sum on one axis. const auto graph_template = R"IR( graph(%0 : Float(5:3,3:1, device=cpu)): - %1 : int = prim::Constant[value=${dim}]() - %2 : int[] = prim::ListConstruct(%1) - %3 : bool = prim::Constant[value=${keepdim}]() - %4 : ${dtype} - %5 : Tensor = aten::sum(%0, %2, %3, %4) - return (%5))IR"; + %1 : int[] = prim::Constant[value=[${dim}]]() + %2 : bool = prim::Constant[value=${keepdim}]() + %3 : ${dtype} + %4 : Tensor = aten::sum(%0, %1, %2, %3) + return (%4))IR"; auto a = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat)); for (int dim = -a.dim(); dim < a.dim(); ++dim) { diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py new file mode 100644 index 000000000000000..74582dc9919ffa9 --- /dev/null +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -0,0 +1,175 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import os + +import numpy as np +import torch +import torch.distributed as c10d +from torch import nn +from torch.distributed.algorithms.ddp_comm_hooks import ( + DDPCommHookType, + register_ddp_comm_hook, +) +from torch.nn.parallel import DistributedDataParallel +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, + requires_nccl, + skip_if_lt_x_gpu, +) +from torch.testing._internal.common_utils import run_tests + + +def gpus_for_rank(world_size): + visible_devices = list(range(torch.cuda.device_count())) + gpus_per_process = torch.cuda.device_count() // world_size + gpus_for_rank = [] + for rank in range(world_size): + gpus_for_rank.append( + visible_devices[rank * gpus_per_process : (rank + 1) * gpus_per_process] + ) + return gpus_for_rank + + +class Task(nn.Module): + def __init__(self): + super(Task, self).__init__() + torch.manual_seed(0) + self.p = nn.Parameter(torch.randn(40, 20)) + + def forward(self, x): + return self.p * x + + +class TestDdpCommHook(nn.Module): + def __init__(self): + super().__init__() + self.t0 = Task() + + def forward(self, x, rank): + return self.t0(x ** (1 + rank)) + + +class DistributedDataParallelCommHookTest(MultiProcessTestCase): + def setUp(self): + super(DistributedDataParallelCommHookTest, self).setUp() + self._fork_processes() + + def tearDown(self): + try: + os.remove(self.file_name) + except OSError: + pass + + @property + def world_size(self): + return 2 + + def _local_model(self): + local_model = TestDdpCommHook().cpu() + + return local_model + + def _get_grads(self, process_group, hook_type=None): + device_id = gpus_for_rank(self.world_size)[self.rank][0] + gpu_model = DistributedDataParallel( + TestDdpCommHook().to(device_id), + device_ids=[device_id], + process_group=process_group, + ) + + # Register DDP Communication Hook if defined + if hook_type is not None: + register_ddp_comm_hook( + comm_hook_type=hook_type, model=gpu_model, state=process_group + ) + + return self._run_and_get_grads(gpu_model) + + def _run_and_get_grads(self, model): + torch.manual_seed(2020) + input = torch.randn(40, 20) + # Run forward + output = model(input, self.rank) + + # Run backward + output.mean().backward() + + return [p.grad.data.cpu().numpy() for p in model.parameters()] + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_allreduce_hook(self): + """ + This unit test verifies the ``allreduce`` hook registered case gives same result + with no hook registered case. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + # No hook registered case, get the reference grads. + reference_grads = self._get_grads(process_group, None) + # Register hook case, get the hook grads. + hook_grads = self._get_grads(process_group, DDPCommHookType.ALLREDUCE) + + np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=0) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_fp16compress_hook(self): + """ + This unit test verifies the ``fp16 compress`` hook registered case + gives close result with no hook registered case. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + # No hook registered case, get the reference grads. + reference_grads = self._get_grads(process_group, None) + # Register hook case, get the hook grads. + hook_grads = self._get_grads(process_group, DDPCommHookType.FP16_COMPRESS) + + np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_quantize_per_tensor_hook(self): + """ + This unit test verifies the ``quantize per tensor`` hook registered case + gives close result with no hook registered case. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + # No hook registered case, get the reference grads. + reference_grads = self._get_grads(process_group, None) + # Register hook case, get the hook grads. + hook_grads = self._get_grads(process_group, DDPCommHookType.QUANTIZE_PER_TENSOR) + + np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_ddp_comm_hook_quantize_per_channel_hook(self): + """ + This unit test verifies the ``quantize per channel`` hook registered case + gives close result with no hook registered case. + """ + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size) + + # No hook registered case, get the reference grads. + reference_grads = self._get_grads(process_group, None) + # Register hook case, get the hook grads. + hook_grads = self._get_grads( + process_group, DDPCommHookType.QUANTIZE_PER_CHANNEL + ) + + np.testing.assert_allclose(hook_grads, reference_grads, rtol=1e-5, atol=1e-4) + + +if __name__ == "__main__": + assert ( + not torch.cuda._initialized + ), "test_distributed must not have initialized CUDA context on main process" + + run_tests() diff --git a/test/jit/test_async.py b/test/jit/test_async.py index bb7b3a51e454c39..22a1a15680d5208 100644 --- a/test/jit/test_async.py +++ b/test/jit/test_async.py @@ -9,7 +9,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from torch.testing._internal.jit_utils import JitTestCase, _inline_everything -from torch.testing._internal.common_utils import TemporaryFileName from typing import List, Tuple from torch import Tensor @@ -491,51 +490,6 @@ def forward(self, input): self.checkTrace(TestModule(), (torch.randn(5, 5),)) - def test_save_load_with_extra_files(self): - class MyMod(torch.jit.ScriptModule): - @torch.jit.script_method - def forward(self, a): - return a - - expected_extra_files = torch._C.ExtraFilesMap() - expected_extra_files['foo'] = 'bar' - m = MyMod() - - # Save to file. - with TemporaryFileName() as fname: - m.save(fname, _extra_files=expected_extra_files) - extra_files = torch._C.ExtraFilesMap() - extra_files['foo'] = '' - torch.jit.load(fname, _extra_files=extra_files) - self.assertEqual('bar', extra_files['foo']) - - # Use torch.jit API - torch.jit.save(m, fname, _extra_files=expected_extra_files) - extra_files['foo'] = '' - torch.jit.load(fname, _extra_files=extra_files) - self.assertEqual('bar', extra_files['foo']) - - # Save to buffer. - buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files)) - extra_files = torch._C.ExtraFilesMap() - extra_files['foo'] = '' - torch.jit.load(buffer, _extra_files=extra_files) - self.assertEqual('bar', extra_files['foo']) - - # Use torch.jit API - buffer = io.BytesIO() - torch.jit.save(m, buffer, _extra_files=expected_extra_files) - buffer.seek(0) - extra_files = torch._C.ExtraFilesMap() - extra_files['foo'] = '' - torch.jit.load(buffer, _extra_files=extra_files) - self.assertEqual('bar', extra_files['foo']) - - # Non-existent file 'bar' - with self.assertRaises(RuntimeError): - extra_files['bar'] = '' - torch.jit.load(buffer, _extra_files=extra_files) - def test_no_future_subtype_message(self): with self.assertRaisesRegex(RuntimeError, 'Future without a contained type'): @torch.jit.script diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py new file mode 100644 index 000000000000000..737bc4528102b55 --- /dev/null +++ b/test/jit/test_profiler.py @@ -0,0 +1,67 @@ +import os +import sys + +import torch + +# Make the helper files in test/ importable +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) +from torch.testing._internal.jit_utils import JitTestCase, warmup_backward, FileCheck + +if __name__ == '__main__': + raise RuntimeError("This test file is not meant to be run directly, use:\n\n" + "\tpython test/test_jit.py TESTNAME\n\n" + "instead.") + +class TestProfiler(JitTestCase): + def setUp(self): + self.prev_exec = torch._C._jit_set_profiling_executor(True) + self.prev_profiling = torch._C._jit_set_profiling_mode(True) + self.inline_autodiff = torch._C._debug_set_autodiff_subgraph_inlining(False) + self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() + torch._C._jit_set_texpr_fuser_enabled(True) + + + def tearDown(self): + torch._C._jit_set_profiling_executor(self.prev_exec) + torch._C._jit_set_profiling_mode(self.prev_profiling) + torch._C._debug_set_autodiff_subgraph_inlining(self.inline_autodiff) + torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) + + def test_specialize_backward(self): + def test_fuse(a, b): + c = a * b + d = c * b + return d + + test_fuse.__disable_jit_function_caching__ = True + + scripted_f = torch.jit.script(test_fuse) + x = torch.ones(1, requires_grad=True) + y = torch.ones(1, requires_grad=True) + scripted_f(x, y) + b = scripted_f(x, y) + warmup_backward(b) + g = torch.jit.last_executed_optimized_graph() + # Backward has an if node guarding specializations, + # within the if node true block there is only one if node + # that guards a tensorexpr group + optimized_block = next(g.findNode("prim::If").blocks()) + if_nodes = list(optimized_block.findAllNodes("prim::If")) + self.assertEqual(len(if_nodes), 1) + FileCheck().check("Group[Subgraph").run(str(if_nodes[0])) + # no broadcasts occurred, sum_to_size have been specialized out + self.assertIsNone(optimized_block.findNode("aten::_grad_sum_to_size")) + + broadcast_f = torch.jit.script(test_fuse) + x = torch.ones([2, 2], requires_grad=True) + y = torch.ones([1], requires_grad=True) + broadcast_f(x, y) + b = broadcast_f(x, y) + b.backward(torch.ones([2, 2], dtype=torch.float)) + b.backward(torch.ones([2, 2], dtype=torch.float)) + # warmup_backward(b, torch.ones([2, 2], dtype=torch.float)) + g = torch.jit.last_executed_optimized_graph() + optimized_block = next(g.findNode("prim::If").blocks()) + # broadcasts occurred, currently expect to see aten::_grad_sum_to_size + self.assertIsNotNone(optimized_block.findNode("aten::_grad_sum_to_size")) diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index 79216891b403e8e..c1ea2b09f103698 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -5,6 +5,7 @@ import torch from itertools import product as product from torch import Tensor +from torch.testing._internal.common_utils import TemporaryFileName from typing import NamedTuple # Make the helper files in test/ importable @@ -874,3 +875,53 @@ def forward(self, x): torch.jit.save(sm, contains_both) contains_both.seek(0) sm = torch.jit.load(contains_both) + + def test_save_load_with_extra_files(self): + class MyMod(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, a): + return a + + # specifically test binary data + value = b"bar\x00\xffbaz" + + expected_extra_files = {} + expected_extra_files['foo'] = value + # verify that str to bytes conversion also works + expected_extra_files['foo2'] = "bar" + m = MyMod() + + # Save to file. + with TemporaryFileName() as fname: + m.save(fname, _extra_files=expected_extra_files) + # values don't matter + extra_files = {'foo': '', 'foo2': None} + torch.jit.load(fname, _extra_files=extra_files) + self.assertEqual(value, extra_files['foo']) + # results come back always as bytes + self.assertEqual(b"bar", extra_files['foo2']) + + # Use torch.jit API + torch.jit.save(m, fname, _extra_files=expected_extra_files) + extra_files['foo'] = '' + torch.jit.load(fname, _extra_files=extra_files) + self.assertEqual(value, extra_files['foo']) + + # Save to buffer. + buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files)) + extra_files = {'foo': ''} + torch.jit.load(buffer, _extra_files=extra_files) + self.assertEqual(value, extra_files['foo']) + + # Use torch.jit API + buffer = io.BytesIO() + torch.jit.save(m, buffer, _extra_files=expected_extra_files) + buffer.seek(0) + extra_files = {'foo': ''} + torch.jit.load(buffer, _extra_files=extra_files) + self.assertEqual(value, extra_files['foo']) + + # Non-existent file 'bar' + with self.assertRaises(RuntimeError): + extra_files['bar'] = '' + torch.jit.load(buffer, _extra_files=extra_files) diff --git a/test/jit/test_with.py b/test/jit/test_with.py index a73a75bc3e3ce76..e79f838217d68d8 100644 --- a/test/jit/test_with.py +++ b/test/jit/test_with.py @@ -548,3 +548,65 @@ def test_exit_incorrect_types(x, c): self.checkScript( test_exit_incorrect_types, (test_tensor, ExitIncorrectTypes()) ) + + def test_with_no_grad(self): + """ + Check that torch.no_grad() works. Most of these are adapted from + corresponding tests for eager-mode no_grad. + """ + + # Basic no_grad test. + def test_no_grad(x, y): + # type: (Tensor, Tensor) -> Tensor + with torch.no_grad(): + w = x + y + + return w + + s = torch.jit.script(test_no_grad) + x = torch.ones(5, 5, requires_grad=True) + y = torch.ones(5, 5) * 4 + w = s(x, y) + + self.assertFalse(w.requires_grad) + self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) + self.assertIsNone(w.grad_fn) + + # Test assignment of a grad-less Tensor to a Tensor with gradients + # in a no_grad block. + def test_no_grad_assignment(x, y): + # type: (Tensor, Tensor) -> Tensor + with torch.no_grad(): + x[0] = y + + return x + + s = torch.jit.script(test_no_grad_assignment) + z = torch.randn(5) + w = s(x, z) + self.assertTrue(w.requires_grad) + self.assertIsNone(w.grad_fn) + + # Check that @torch.jit.ignored functions respect no_grad when it is + # called in JIT mode. + class NoGradModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @torch.jit.ignore + def adder(self, x, y): + # type: (Tensor, Tensor) -> Tensor + w = x + y + return w + + def forward(self, x, y): + # type: (Tensor, Tensor) -> Tensor + with torch.no_grad(): + w = self.adder(x, y) + + return w + + s = torch.jit.script(NoGradModule()) + w = s(x, y) + + self.assertFalse(w.requires_grad) diff --git a/test/mobile/custom_build/build.sh b/test/mobile/custom_build/build.sh index b0b6d3a6bc66e9a..5affa6d3ab88162 100755 --- a/test/mobile/custom_build/build.sh +++ b/test/mobile/custom_build/build.sh @@ -4,23 +4,18 @@ # size for mobile devices and the flow to integrate it with a simple predictor # in c++. # -# There are three custom build types: +# Supported custom build types: # # 1. `TEST_DEFAULT_BUILD=1 ./build.sh` - it is similar to the prebuilt libtorch # libraries released for Android and iOS (same CMake build options + host # toolchain), which doesn't contain autograd function nor backward ops thus is # smaller than full LibTorch. # -# 2. `TEST_CUSTOM_BUILD_STATIC=1 ./build.sh` - it further optimizes libtorch +# 2. `TEST_CUSTOM_BUILD_DYNAMIC=1 ./build.sh` - it further optimizes libtorch # size by only including ops used by a specific model. -# -# 3. `TEST_CUSTOM_BUILD_DYNAMIC=1 ./build.sh` - similar as 2) except that it -# relies on the op dependency graph (instead of static dispatch) to calculate -# and keep all transitively dependent ops by the model. # Note that LLVM_DIR environment variable should be set to the location of # LLVM-dev toolchain. # -# Type 2) will be deprecated by type 3) in the future. ############################################################################### set -ex -o pipefail @@ -59,17 +54,6 @@ run_default_build() { "${SRC_ROOT}/scripts/build_mobile.sh" } -run_custom_build_with_static_dispatch() { - LIBTORCH_BUILD_ROOT="${BUILD_ROOT}/build_custom_libtorch_static" - LIBTORCH_INSTALL_PREFIX="${LIBTORCH_BUILD_ROOT}/install" - - BUILD_ROOT="${LIBTORCH_BUILD_ROOT}" \ - "${SRC_ROOT}/scripts/build_mobile.sh" \ - -DCMAKE_CXX_FLAGS="-DSTRIP_ERROR_MESSAGES" \ - -DUSE_STATIC_DISPATCH=ON \ - -DSELECTED_OP_LIST="${ROOT_OPS}" -} - run_custom_build_with_dynamic_dispatch() { LIBTORCH_BUILD_ROOT="${BUILD_ROOT}/build_custom_libtorch_dynamic" LIBTORCH_INSTALL_PREFIX="${LIBTORCH_BUILD_ROOT}/install" @@ -77,7 +61,6 @@ run_custom_build_with_dynamic_dispatch() { BUILD_ROOT="${LIBTORCH_BUILD_ROOT}" \ "${SRC_ROOT}/scripts/build_mobile.sh" \ -DCMAKE_CXX_FLAGS="-DSTRIP_ERROR_MESSAGES" \ - -DUSE_STATIC_DISPATCH=OFF \ -DSELECTED_OP_LIST="${ROOT_OPS}" \ -DOP_DEPENDENCY="${OP_DEPENDENCY}" } @@ -115,13 +98,6 @@ test_default_build() { run_predictor } -test_custom_build_with_static_dispatch() { - prepare_model_and_dump_root_ops - run_custom_build_with_static_dispatch - build_predictor - run_predictor -} - test_custom_build_with_dynamic_dispatch() { prepare_model_and_dump_root_ops generate_op_dependency_graph @@ -134,10 +110,6 @@ if [ -n "${TEST_DEFAULT_BUILD}" ]; then test_default_build fi -if [ -n "${TEST_CUSTOM_BUILD_STATIC}" ]; then - test_custom_build_with_static_dispatch -fi - if [ -n "${TEST_CUSTOM_BUILD_DYNAMIC}" ]; then test_custom_build_with_dynamic_dispatch fi diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index 4e80f86bb9c38c9..5959e9c7a5b5b11 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -1774,6 +1774,20 @@ def forward(self, input, indices): indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) self.run_test(ScatterModel(), input=(input, indices)) + @skipIfUnsupportedMinOpsetVersion(9) + def test_scatter_with_scalar_different_types(self): + # Tests the case when scalar src (updates values) type is different + # from self type. Happens only with scalar src - PyTorch does not + # allow this when src is a tensor. + class ScatterModel(torch.nn.Module): + def forward(self, input, indices): + values = 1.0 + return input.scatter(1, indices, values) + + input = torch.tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=torch.float32) + indices = torch.tensor([[1, 0], [0, 1], [0, 1]], dtype=torch.int64) + self.run_test(ScatterModel(), input=(input, indices)) + @skipIfUnsupportedMinOpsetVersion(9) def test_scatter(self): class ScatterModel(torch.nn.Module): @@ -3821,6 +3835,34 @@ def forward(self, input): input = torch.randn(2, 5, 7, dtype=torch.float64) self.run_test(Celu(), (input,)) + @skipIfUnsupportedMinOpsetVersion(9) + def test_where(self): + class Model(torch.nn.Module): + def forward(self, cond, input, other): + return torch.where(cond, input, other) + + x = torch.randint(0, 1, (2, 3, 4), dtype=torch.bool) + y = torch.randn(2, 1, 4) + z = torch.ones(2, 3, 1) + self.run_test(Model(), (x, y, z)) + + @skipIfUnsupportedMinOpsetVersion(9) + def test_where_condition(self): + class Model1(torch.nn.Module): + def forward(self, input): + return torch.stack(torch.where(input > 0.5), dim=1) + + x = torch.randint(0, 2, (2, 3, 4), dtype=bool) + self.run_test(Model1(), (x)) + + class Model2(torch.nn.Module): + def forward(self, input, other): + return torch.stack(torch.where(input > other), dim=1) + + x = torch.randint(0, 1, (2, 3, 4), dtype=bool) + y = torch.randint(1, 2, (2, 3, 4), dtype=bool) + self.run_test(Model2(), (x, y)) + def test_empty_branch(self): class EmptyBranchModel(torch.jit.ScriptModule): @torch.jit.script_method diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 2dde1cabd17c892..574bb5fcf5a2bc1 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -520,6 +520,25 @@ def forward(self, x): # verify that the model state is preserved assert model.training == old_state + def test_diagnose_export_mode(self): + class MyModule(torch.nn.Module): + def forward(self, x): + return torch.cumsum(x, dim=0) + + model = MyModule() + x = torch.randn(2, 3, 4) + f = io.BytesIO() + + # run export in diagnose mode + graph, unsupported_ops = torch.onnx._diagnose_export(model, (x,), f, + opset_version=9) + iter = graph.nodes() + assert next(iter).kind() == "onnx::Constant" + assert next(iter).kind() == "prim::Constant" + assert next(iter).kind() == "aten::cumsum" + assert len(unsupported_ops) == 1 + assert unsupported_ops == ['aten::cumsum'] + def test_dropout_training(self): class MyModule(torch.nn.Module): def __init__(self): diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph.expected.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph.expected.pt new file mode 100644 index 000000000000000..0a70645ce35b77f Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph.expected.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph.input.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph.input.pt new file mode 100644 index 000000000000000..3a16e51b9b24829 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph.input.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph.scripted.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph.scripted.pt new file mode 100644 index 000000000000000..46e79a82cc25b24 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph.scripted.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph.traced.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph.traced.pt new file mode 100644 index 000000000000000..817f1e1c8f91f56 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph.traced.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.expected.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.expected.pt new file mode 100644 index 000000000000000..4d5ab4035079679 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.expected.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.input.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.input.pt new file mode 100644 index 000000000000000..2bed991b9449cb1 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.input.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.scripted.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.scripted.pt new file mode 100644 index 000000000000000..ebd68aefa5ec1f1 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.scripted.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.traced.pt b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.traced.pt new file mode 100644 index 000000000000000..230743c8fb6db27 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_graph_v2.traced.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias.expected.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.expected.pt new file mode 100644 index 000000000000000..e48fb09aeaabd49 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.expected.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias.input.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.input.pt new file mode 100644 index 000000000000000..a9748471efffb52 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.input.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias.scripted.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.scripted.pt new file mode 100644 index 000000000000000..25613c57c08d4f7 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.scripted.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias.state_dict.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.state_dict.pt new file mode 100644 index 000000000000000..35a88e4d91fd678 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.state_dict.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias.traced.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.traced.pt new file mode 100644 index 000000000000000..6bed65123b30a3c Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias.traced.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.expected.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.expected.pt new file mode 100644 index 000000000000000..b997a2c07b2afca Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.expected.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.input.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.input.pt new file mode 100644 index 000000000000000..f4f7e10ec72cef7 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.input.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.scripted.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.scripted.pt new file mode 100644 index 000000000000000..3c26ffb4ab36aed Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.scripted.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.traced.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.traced.pt new file mode 100644 index 000000000000000..a615972d9bb3fd1 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph.traced.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.expected.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.expected.pt new file mode 100644 index 000000000000000..31eb62469b3b816 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.expected.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.input.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.input.pt new file mode 100644 index 000000000000000..65b652e92089f07 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.input.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.scripted.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.scripted.pt new file mode 100644 index 000000000000000..eb6cc6958f22c00 Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.scripted.pt differ diff --git a/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.traced.pt b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.traced.pt new file mode 100644 index 000000000000000..630c088fc0f012b Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_conv2d_nobias_graph_v2.traced.pt differ diff --git a/test/quantization/test_backward_compatibility.py b/test/quantization/test_backward_compatibility.py index e6d77e39fa02c3f..f4e8518de5195ef 100644 --- a/test/quantization/test_backward_compatibility.py +++ b/test/quantization/test_backward_compatibility.py @@ -5,6 +5,7 @@ # torch import torch +import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd import torch.nn.intrinsic.quantized as nniq @@ -13,6 +14,35 @@ from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm +def remove_prefix(text, prefix): + if text.startswith(prefix): + return text[len(prefix):] + return text + +def get_filenames(self, subname): + # NB: we take __file__ from the module that defined the test + # class, so we place the expect directory where the test script + # lives, NOT where test/common_utils.py lives. + module_id = self.__class__.__module__ + munged_id = remove_prefix(self.id(), module_id + ".") + test_file = os.path.realpath(sys.modules[module_id].__file__) + base_name = os.path.join(os.path.dirname(test_file), + "serialized", + munged_id) + + subname_output = "" + if subname: + base_name += "_" + subname + subname_output = " ({})".format(subname) + + input_file = base_name + ".input.pt" + state_dict_file = base_name + ".state_dict.pt" + scripted_module_file = base_name + ".scripted.pt" + traced_module_file = base_name + ".traced.pt" + expected_file = base_name + ".expected.pt" + + return input_file, state_dict_file, scripted_module_file, traced_module_file, expected_file + class TestSerialization(TestCase): """ Test backward compatiblity for serialization and numerics """ @@ -23,30 +53,8 @@ def _test_op(self, qmodule, subname=None, input_size=None, input_quantized=True, with current code, make sure we don't break backward compatibility for the serialization of quantized modules """ - def remove_prefix(text, prefix): - if text.startswith(prefix): - return text[len(prefix):] - return text - # NB: we take __file__ from the module that defined the test - # class, so we place the expect directory where the test script - # lives, NOT where test/common_utils.py lives. - module_id = self.__class__.__module__ - munged_id = remove_prefix(self.id(), module_id + ".") - test_file = os.path.realpath(sys.modules[module_id].__file__) - base_name = os.path.join(os.path.dirname(test_file), - "serialized", - munged_id) - - subname_output = "" - if subname: - base_name += "_" + subname - subname_output = " ({})".format(subname) - - input_file = base_name + ".input.pt" - state_dict_file = base_name + ".state_dict.pt" - scripted_module_file = base_name + ".scripted.pt" - traced_module_file = base_name + ".traced.pt" - expected_file = base_name + ".expected.pt" + input_file, state_dict_file, scripted_module_file, traced_module_file, expected_file = \ + get_filenames(self, subname) # only generate once. if generate and qengine_is_fbgemm(): @@ -69,6 +77,51 @@ def remove_prefix(text, prefix): self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) + def _test_op_graph(self, qmodule, subname=None, input_size=None, input_quantized=True, + generate=False, prec=None, new_zipfile_serialization=False): + r""" + Input: a floating point module + + If generate == True, traces and scripts the module and quantizes the results with + PTQ, and saves the results. + + If generate == False, traces and scripts the module and quantizes the results with + PTQ, and compares to saved results. + """ + input_file, state_dict_file, scripted_module_file, traced_module_file, expected_file = \ + get_filenames(self, subname) + + # only generate once. + if generate and qengine_is_fbgemm(): + input_tensor = torch.rand(*input_size).float() + torch.save(input_tensor, input_file) + + # convert to TorchScript + scripted = torch.jit.script(qmodule) + traced = torch.jit.trace(qmodule, input_tensor) + + # quantize + + def _eval_fn(model, data): + model(data) + + qconfig_dict = {'': torch.quantization.default_qconfig} + scripted_q = torch.quantization.quantize_jit( + scripted, qconfig_dict, _eval_fn, [input_tensor]) + traced_q = torch.quantization.quantize_jit( + traced, qconfig_dict, _eval_fn, [input_tensor]) + + torch.jit.save(scripted_q, scripted_module_file) + torch.jit.save(traced_q, traced_module_file) + torch.save(scripted_q(input_tensor), expected_file) + + input_tensor = torch.load(input_file) + qmodule_scripted = torch.jit.load(scripted_module_file) + qmodule_traced = torch.jit.load(traced_module_file) + expected = torch.load(expected_file) + self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) + self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) + @override_qengines def test_linear(self): module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8) @@ -92,7 +145,52 @@ def test_conv2d(self): module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode="zeros") self._test_op(module, input_size=[1, 3, 6, 6], generate=False) - # TODO: graph mode quantized conv2d module + + @override_qengines + def test_conv2d_nobias(self): + module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, + groups=1, bias=False, padding_mode="zeros") + self._test_op(module, input_size=[1, 3, 6, 6], generate=False) + + @override_qengines + def test_conv2d_graph(self): + module = nn.Sequential( + torch.quantization.QuantStub(), + nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, + groups=1, bias=True, padding_mode="zeros"), + ) + self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) + + @override_qengines + def test_conv2d_nobias_graph(self): + module = nn.Sequential( + torch.quantization.QuantStub(), + nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, + groups=1, bias=False, padding_mode="zeros"), + ) + self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) + + @override_qengines + def test_conv2d_graph_v2(self): + # tests the same thing as test_conv2d_graph, but for version 2 of + # ConvPackedParams{n}d + module = nn.Sequential( + torch.quantization.QuantStub(), + nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, + groups=1, bias=True, padding_mode="zeros"), + ) + self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) + + @override_qengines + def test_conv2d_nobias_graph_v2(self): + # tests the same thing as test_conv2d_nobias_graph, but for version 2 of + # ConvPackedParams{n}d + module = nn.Sequential( + torch.quantization.QuantStub(), + nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, + groups=1, bias=False, padding_mode="zeros"), + ) + self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) @override_qengines def test_conv2d_relu(self): diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index 658c722c1b3e43b..1eeafcf605e2806 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -4,6 +4,7 @@ import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd import torch.nn.intrinsic.quantized as nniq +import torch.multiprocessing as mp # symbolic trace from torch.fx import symbolic_trace @@ -11,21 +12,34 @@ # graph mode quantization based on fx from torch.quantization._quantize_fx import ( Quantizer, + fuse, QuantType, ) -from torch.quantization import default_qconfig +from torch.quantization import ( + default_qconfig, + default_qat_qconfig, + prepare, + prepare_qat, + convert, +) # test utils from torch.testing._internal.common_quantization import ( QuantizationTestCase, skipIfNoFBGEMM, + skip_if_no_torchvision, + train_one_epoch, + run_ddp, ) +from torch.testing._internal.common_distributed import skip_if_not_multigpu + from torch.testing._internal.common_quantization import NodeSpec as ns import itertools import operator +import unittest class TestQuantizeFx(QuantizationTestCase): """ Unit tests for functionalities @@ -757,3 +771,202 @@ def forward(self, x): quantized, expected_node_occurrence=count_check, expected_node_list=order_check) + +class TestQuantizeFxModels(QuantizationTestCase): + def _test_model_impl( + self, mode, name, model, eager_quantizable_model, + check_with_eager=True, + diff_of_quant=None, + diff_from_eager=None): + if diff_of_quant is None or diff_from_eager is None: + diff_of_quant = {} + diff_from_eager = {} + + if mode not in diff_of_quant or mode not in diff_from_eager: + diff_of_quant[mode] = {} + diff_from_eager[mode] = {} + + input_tensor = torch.rand(1, 3, 224, 224) + input_tensor_inception = torch.rand(1, 3, 299, 299) + output_value = torch.randint(0, 1, (1,)) + + # print('quantizing:', name, ' mode:', mode) + if name == 'inception_v3': + input_value = input_tensor_inception + else: + input_value = input_tensor + + qconfig = default_qconfig if mode == 'static' else default_qat_qconfig + qconfig_dict = {'': qconfig} + graph_module = symbolic_trace(model) + # print('graph module:', graph_module.src) + script = torch.jit.script(graph_module) + + # make sure graph module and script module are both runanble + original_out = graph_module(input_value) + is_not_tuple_out = not isinstance(original_out, tuple) + script_out = script(input_value) + self.assertEqual( + (original_out - script_out).abs().max(), 0, + 'Reslut of original graph module and script module does not match') + + # set to train just before quantization + if mode != 'static': + model.train() + + graph_module = fuse(graph_module) + quantizer = Quantizer() + prepared = quantizer.prepare(graph_module, qconfig_dict) + + if mode == 'ddp': + mp.spawn(run_ddp, + args=(world_size, prepared), + nprocs=world_size, + join=True) + elif mode == 'qat': + assert prepared.training, 'prepared must be in training mode for qat' + optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001) + criterion = nn.CrossEntropyLoss() + train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) + else: + for i in range(10): + prepared(input_value) + + # print('after observation root:', prepared.root) + + qgraph = quantizer.convert(prepared) + # print('after quantization root:', qgraph.root) + # print('after quantization code:', qgraph.src) + qgraph.eval() + qgraph_script = torch.jit.script(qgraph) + # print('quantized and scripted:', qgraph_script.graph) + + qgraph_out = qgraph(input_value) + qgraph_script = qgraph_script(input_value) + + if is_not_tuple_out: + diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max() + assert torch.allclose(qgraph_out, qgraph_script), 'graph, scripted graph' + else: + print('tuple output') + + if eager_quantizable_model is not None: + # comparing to eager mode quantization + qeager = eager_quantizable_model + ref_out = qeager(input_value) + qeager.qconfig = qconfig + if mode == 'static': + qeager.fuse_model() + prepare(qeager, inplace=True) + else: + qeager.train() + qeager.fuse_model() + prepare_qat(qeager, inplace=True) + + # calibration + if mode == 'ddp': + mp.spawn(run_ddp, + args=(world_size, qeager), + nprocs=world_size, + join=True) + elif mode == 'qat': + assert qeager.training, 'qeager should be in training mode for qat' + optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001) + train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1) + else: + for i in range(10): + qeager(input_value) + + # print('ref after observation:', qeager) + + convert(qeager, inplace=True) + qeager.eval() + + # print('ref after quantization:', qeager) + qeager_out = qeager(input_value) + qeager_script = torch.jit.script(qeager) + qscript_out = qeager_script(input_value) + if is_not_tuple_out: + diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max() + if check_with_eager: + self.assertEqual(diff_from_eager[mode][name], 0, + 'Result of graph mode quantization and ' + + 'eager mode quantization on model: ' + name + + ' should match. Mode: ' + mode + + ' diff:' + str(diff_from_eager[mode][name])) + + @skip_if_no_torchvision + @skipIfNoFBGEMM + @unittest.skip("skip for now since tbb failed") + def test_torchvision(self): + from torchvision import models + from torchvision.models import quantization as quantized_models + + def get_available_classification_models(models): + return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + + model_list = get_available_classification_models(models) + quantized_model_list = get_available_classification_models(quantized_models) + + no_pretrained_model = set(['shufflenet_v2_x0_5', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0']) + quantized_model_list = set(quantized_model_list) - no_pretrained_model + # test eager and graph consistency + model_list = quantized_model_list + # slice need to be fixed in symbolic tracing(https://github.com/pytorch/pytorch/issues/43511) + model_list = set(model_list) - {'googlenet', 'inception_v3'} + # getattr should not be used as node name(https://github.com/pytorch/pytorch/issues/43522) + model_list -= {'shufflenet_v2_x1_0', 'mobilenet_v2'} + + # mobilenet: dropout error RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'QUInt8' + # incpetion_v3: looks like there is some problem with AuxLogits + quantized_not_working = [('qat', 'mobilenet_v2'), + ('qat', 'inception_v3'), + ('static', 'inception_v3')] + + fx_eager_not_matching = ['googlenet', # because _transform_input is not quantized in eager + 'mobilenet_v2'] # because relu6 is replaced as relu in mobilenetv2 + + diff_of_quant = {} + diff_from_eager = {} + modes = ['static', 'qat'] + options = itertools.product(modes, model_list) + for mode, name in options: + pretrained = name in quantized_model_list # load pretrained model to compare with quantized model + if name in quantized_model_list: + if (mode, name) in quantized_not_working: + eager_quantizable_model = None + else: + eager_quantizable_model = quantized_models.__dict__[name](pretrained=True, quantize=False).eval().float() + # compare with eager mode quantized model when it is available + pretrained = eager_quantizable_model is not None + model = models.__dict__[name](pretrained=pretrained).eval().float() + check_with_eager = name not in fx_eager_not_matching + self._test_model_impl( + mode, name, model, eager_quantizable_model, + check_with_eager, + diff_of_quant, diff_from_eager) + + def print_diffs(diffs): + for mode, diffs_for_mode in diffs.items(): + print('mode:', mode) + for name, diff in diffs_for_mode.items(): + print(name, ':', diff) + + # print('differences between float and quantized') + # print_diffs(diff_of_quant) + # print('----------------------') + # print('differences between graph mode and eager mode') + # print_diffs(diff_from_eager) + # print('----------------------') + + @skip_if_no_torchvision + @skip_if_not_multigpu + @skipIfNoFBGEMM + @unittest.skip('TODO: not working yet due to https://github.com/pytorch/pytorch/issues/43513') + def test_resnet18_ddp(self): + from torchvision import models + from torchvision.models import quantization as quantized_models + eager_quantizable_model = quantized_models.__dict__[name](pretrained=True, quantize=False).eval().float() + model = models.__dict__[name](pretrained=True).eval().float() + self._test_model_impl( + 'ddp', 'resnet18', model, eager_quantizable_model) diff --git a/test/run_test.py b/test/run_test.py index fcdce3665e1f0ef..5a94eb630604704 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -140,6 +140,7 @@ 'distributed/test_distributed', 'distributed/rpc/test_process_group_agent', 'distributed/rpc/test_tensorpipe_agent', + 'distributed/algorithms/ddp_comm_hooks/test_ddp_hooks', 'test_cuda', 'test_cuda_primary_ctx', 'test_cpp_extensions_aot_ninja', diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 82e871a45497eca..06f3df93a99e260 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -285,6 +285,7 @@ def test_inline_jit_compile_extension_multiple_sources_and_no_functions(self): z = module.sin_add(x, y) self.assertEqual(z, x.sin() + y.sin()) + @unittest.skip("Temporarily disabled") @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found") def test_inline_jit_compile_extension_cuda(self): cuda_source = """ @@ -327,6 +328,7 @@ def test_inline_jit_compile_extension_cuda(self): z = module.cos_add(x, y) self.assertEqual(z, x.cos() + y.cos()) + @unittest.skip("Temporarily disabled") @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found") def test_inline_jit_compile_custom_op_cuda(self): cuda_source = """ @@ -401,6 +403,7 @@ def test_lenient_flag_handling_in_jit_extensions(self): z = module.tanh_add(x, y).cpu() self.assertEqual(z, x.tanh() + y.tanh()) + @unittest.skip("Temporarily disabled") @unittest.skipIf(not (TEST_CUDA or TEST_ROCM), "CUDA not found") def test_half_support(self): """ diff --git a/test/test_foreach.py b/test/test_foreach.py index 388e9261b676ee6..b8c83b9b681f37b 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -1,56 +1,113 @@ import torch -import torch.cuda from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes class TestForeach(TestCase): @dtypes(*torch.testing.get_all_dtypes()) - def test_add_scalar_with_same_size_tensors(self, device, dtype): - N = 20 - H = 20 - W = 20 - tensors = [] - for _ in range(N): - tensors.append(torch.zeros(H, W, device=device, dtype=dtype)) + def test_int_scalar(self, device, dtype): + tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] + int_scalar = 1 - res = torch._foreach_add(tensors, 1) - for t in res: - if dtype == torch.bool: - dtype = torch.int64 - self.assertEqual(t, torch.ones(H, W, device=device, dtype=dtype)) + # bool tensor + 1 will result in int64 tensor + if dtype == torch.bool: + expected = [torch.ones(10, 10, device=device, dtype=torch.int64) for _ in range(10)] + else: + expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)] + + res = torch._foreach_add(tensors, int_scalar) + self.assertEqual(res, expected) + + if dtype in [torch.bool]: + with self.assertRaisesRegex(RuntimeError, "result type Long can't be cast to the desired output type Bool"): + torch._foreach_add_(tensors, int_scalar) + else: + torch._foreach_add_(tensors, int_scalar) + self.assertEqual(res, tensors) @dtypes(*torch.testing.get_all_dtypes()) - def test_add_scalar_with_different_size_tensors(self, device, dtype): - N = 20 - H = 20 - W = 20 + def test_float_scalar(self, device, dtype): + tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] + float_scalar = 1. + + # float scalar + integral tensor will result in float tensor + if dtype in [torch.uint8, torch.int8, torch.int16, + torch.int32, torch.int64, torch.bool]: + expected = [torch.ones(10, 10, device=device, dtype=torch.float32) for _ in range(10)] + else: + expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)] + + res = torch._foreach_add(tensors, float_scalar) + self.assertEqual(res, expected) - tensors = [] - size_change = 0 - for _ in range(N): - tensors.append(torch.zeros(H + size_change, W + size_change, device=device, dtype=dtype)) - size_change += 1 + if dtype in [torch.uint8, torch.int8, torch.int16, + torch.int32, torch.int64, torch.bool]: + self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, float_scalar)) + else: + torch._foreach_add_(tensors, float_scalar) + self.assertEqual(res, tensors) - res = torch._foreach_add(tensors, 1) + @dtypes(*torch.testing.get_all_dtypes()) + def test_complex_scalar(self, device, dtype): + tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] + complex_scalar = 3 + 5j + + # bool tensor + 1 will result in int64 tensor + expected = [torch.add(complex_scalar, torch.zeros(10, 10, device=device, dtype=dtype)) for _ in range(10)] + + if dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16] and device == 'cuda:0': + # value cannot be converted to dtype without overflow: + self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar)) + self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, complex_scalar)) + return + + res = torch._foreach_add(tensors, complex_scalar) + self.assertEqual(res, expected) + + if dtype not in [torch.complex64, torch.complex128]: + self.assertRaises(RuntimeError, lambda: torch._foreach_add_(tensors, complex_scalar)) + else: + torch._foreach_add_(tensors, complex_scalar) + self.assertEqual(res, tensors) + + @dtypes(*torch.testing.get_all_dtypes()) + def test_bool_scalar(self, device, dtype): + tensors = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)] + bool_scalar = True + + expected = [torch.ones(10, 10, device=device, dtype=dtype) for _ in range(10)] + + res = torch._foreach_add(tensors, bool_scalar) + self.assertEqual(res, expected) + + torch._foreach_add_(tensors, bool_scalar) + self.assertEqual(res, tensors) + + @dtypes(*torch.testing.get_all_dtypes()) + def test_add_scalar_with_different_size_tensors(self, device, dtype): + if dtype == torch.bool: + return - size_change = 0 - for t in res: - if dtype == torch.bool: - dtype = torch.int64 - self.assertEqual(t, torch.ones(H + size_change, W + size_change, device=device, dtype=dtype)) - size_change += 1 + tensors = [torch.zeros(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)] + expected = [torch.ones(10 + n, 10 + n, device=device, dtype=dtype) for n in range(10)] + torch._foreach_add_(tensors, 1) + self.assertEqual(expected, tensors) @dtypes(*torch.testing.get_all_dtypes()) - def test_add_scalar_with_empty_list(self, device, dtype): - tensors = [] - with self.assertRaises(RuntimeError): - torch._foreach_add(tensors, 1) + def test_add_scalar_with_empty_list_and_empty_tensor(self, device, dtype): + # TODO: enable empty list case + for tensors in [[torch.randn([0])]]: + res = torch._foreach_add(tensors, 1) + self.assertEqual(res, tensors) + + torch._foreach_add_(tensors, 1) + self.assertEqual(res, tensors) @dtypes(*torch.testing.get_all_dtypes()) def test_add_scalar_with_overlapping_tensors(self, device, dtype): tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)] expected = [torch.tensor([[[2, 2, 2]], [[2, 2, 2]]], dtype=dtype, device=device)] + # bool tensor + 1 will result in int64 tensor if dtype == torch.bool: expected[0] = expected[0].to(torch.int64).add(1) @@ -58,43 +115,9 @@ def test_add_scalar_with_overlapping_tensors(self, device, dtype): self.assertEqual(res, expected) def test_add_scalar_with_different_tensor_dtypes(self, device): - tensors = [torch.tensor([1], dtype=torch.float, device=device), - torch.tensor([1], dtype=torch.int, device=device)] - - expected = [torch.tensor([2], dtype=torch.float, device=device), - torch.tensor([2], dtype=torch.int, device=device)] - - res = torch._foreach_add(tensors, 1) - self.assertEqual(res, expected) - - def test_add_scalar_with_different_scalar_type(self, device): - # int tensor with float scalar - # should go 'slow' route - scalar = 1.1 - tensors = [torch.tensor([1], dtype=torch.int, device=device)] - res = torch._foreach_add(tensors, scalar) - self.assertEqual(res, [torch.tensor([2.1], device=device)]) - - # float tensor with int scalar - # should go 'fast' route - scalar = 1 - tensors = [torch.tensor([1.1], device=device)] - res = torch._foreach_add(tensors, scalar) - self.assertEqual(res, [torch.tensor([2.1], device=device)]) - - # bool tensor with int scalar - # should go 'slow' route - scalar = 1 - tensors = [torch.tensor([False], device=device)] - res = torch._foreach_add(tensors, scalar) - self.assertEqual(res, [torch.tensor([1], device=device)]) - - # bool tensor with float scalar - # should go 'slow' route - scalar = 1.1 - tensors = [torch.tensor([False], device=device)] - res = torch._foreach_add(tensors, scalar) - self.assertEqual(res, [torch.tensor([1.1], device=device)]) + tensors = [torch.tensor([1.1], dtype=torch.float, device=device), + torch.tensor([1], dtype=torch.long, device=device)] + self.assertRaises(RuntimeError, lambda: torch._foreach_add(tensors, 1)) instantiate_device_type_tests(TestForeach, globals()) diff --git a/test/test_fx.py b/test/test_fx.py index 4b863acb7110dcc..a3b994af4570048 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -320,10 +320,10 @@ def __init__(self, interpreter): # Add placeholders for fn inputs placeholder_nodes = [] for name in fn_input_names: - placeholder_nodes.append(graph.placeholder(name)) + placeholder_nodes.append(graph.create_node('placeholder', name)) # Get the interpreter object - interpreter_node = graph.get_param('interpreter') + interpreter_node = graph.create_node('get_param', 'interpreter') # Add a node to call the interpreter instance output_node = graph.create_node( @@ -355,5 +355,35 @@ def __init__(self, interpreter): imported_out = import_copy(x) torch.testing.assert_allclose(imported_out, ref_out) + def test_reserved_getattr(self): + """Ensure that we do not name any nodes with a reserved builtin like `getattr`""" + class M(torch.nn.Module): + def forward(self, a): + return a.foo.bar.baz + + m = M() + m_g = symbolic_trace(m) + for node in m_g.graph.nodes: + self.assertTrue(node.name != "getattr") + + def test_node_tagging(self): + class TaggingDelegate(DefaultDelegate): + def create_node(self, kind : str, target : Union[str, Callable], + args : Tuple[Any], kwargs : Dict[str, Any], name : Optional[str] = None) -> Node: + n = super().create_node(kind, target, args, kwargs, name) + n.tag = 'foo' + return n + + class M(torch.nn.Module): + def forward(self, a, b): + return a + b + + m = M() + g = symbolic_trace(m, TaggingDelegate).graph + for n in g.nodes: + self.assertTrue(hasattr(n, 'tag')) + self.assertEqual(n.tag, 'foo') + + if __name__ == '__main__': run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index 99bd1194719704b..358efca36a709bf 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -30,6 +30,7 @@ from jit.test_onnx_export import TestONNXExport # noqa: F401 from jit.test_with import TestWith # noqa: F401 from jit.test_enum import TestEnum, TestEnumFeatureGuard # noqa: F401 +from jit.test_profiler import TestProfiler # noqa: F401 # Torch from torch import Tensor @@ -5270,7 +5271,6 @@ def def_in_one_branch(x, z): # this triggers 2 bailouts self.assertEqual(def_in_one_branch(a, True), 3.0) - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") def test_maxpool_guard_elimination(self): @torch.jit.script @@ -5670,6 +5670,16 @@ def fn(x, y, z): ast = torch.jit.frontend.get_jit_def(fn, fn.__name__) self.assertExpected(str(ast)) + def test_python_frontend_source_range(self): + def fn(): + raise Exception("hello") + ast = torch.jit.frontend.get_jit_def(fn, fn.__name__) + FileCheck().check("SourceRange at:") \ + .check("def fn():") \ + .check("~~~~~~~~~... <--- HERE") \ + .check('raise Exception("hello")') \ + .run(str(ast.range())) + def test_python_frontend_py3(self): def fn(): raise Exception("hello") @@ -12395,7 +12405,7 @@ def close_match(x): "supported in TorchScript"): @torch.jit.script def unknown_op(x): - torch.set_grad_enabled(True) + torch.set_anomaly_enabled(True) return x def test_exceptions(self): diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 5f0989d2c7bcfa8..ced6b7e50271221 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -12,7 +12,7 @@ from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, _inline_everything, \ - RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU + RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward from textwrap import dedent from itertools import product, permutations @@ -29,19 +29,6 @@ def strip_profiling_nodes(nodes): return [n for n in nodes if n.kind() not in profiling_opcodes] -def warmup_backward(f, *args): - profiling_count = 2 - results = [] - for i in range(profiling_count): - if len(args) > 0: - r = torch.autograd.grad(f, *args) - results.append(r) - else: - f.backward(retain_graph=True) - - return results - - def warmup_forward(f, *args): profiling_count = 2 for i in range(profiling_count): diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index 6726133508bc919..a8a394744122649 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -21,7 +21,7 @@ from torch.testing._internal.common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \ enable_profiling_mode_for_profiling_tests from torch.testing._internal.jit_utils import JitTestCase, _inline_everything, \ - RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU + RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward from textwrap import dedent from itertools import product, permutations @@ -37,20 +37,6 @@ def strip_profiling_nodes(nodes): profiling_opcodes = set(['prim::BailoutTemplate', 'prim::BailOut']) return [n for n in nodes if n.kind() not in profiling_opcodes] - -def warmup_backward(f, *args): - profiling_count = 2 - results = [] - for i in range(profiling_count): - if len(args) > 0: - r = torch.autograd.grad(f, *args) - results.append(r) - else: - f.backward(retain_graph=True) - - return results - - def warmup_forward(f, *args): profiling_count = 2 for i in range(profiling_count): @@ -58,7 +44,6 @@ def warmup_forward(f, *args): return results - class TestTEFuser(JitTestCase): def setUp(self): self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() @@ -117,6 +102,45 @@ def func(x): self.assertEqual(len(fusion_groups), 1) FileCheck().check("aten::abs").check("aten::mul").run(str(fusion_groups[0])) + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") + def test_sum_simple(self): + def func(x): + return x.sum() * 2 + + a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') + a = a.reshape(5, 3) + scripted = self.checkScript(func, (a,)) + graph = scripted.graph_for(a) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + self.assertEqual(scripted(a), func(a)) + + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") + def test_sum_dim(self): + def func(x): + return x.sum((0, )) * 2 + + a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') + a = a.reshape(5, 3) + scripted = self.checkScript(func, (a,)) + graph = scripted.graph_for(a) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + self.assertEqual(scripted(a), func(a)) + + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") + def test_sum_keepdim_cast(self): + def func(x): + return x.sum((0, ), keepdim=True, dtype=torch.double) * 2 + + a = torch.tensor(list(x for x in range(0, 15)), dtype=torch.float, device='cpu') + a = a.reshape(5, 3) + scripted = self.checkScript(func, (a,)) + graph = scripted.graph_for(a) + fusion_groups = self.findFusionGroups(graph) + self.assertEqual(len(fusion_groups), 1) + self.assertEqual(scripted(a), func(a)) + @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") def test_abs_cpu(self): self._test_fused_abs() diff --git a/test/test_linalg.py b/test/test_linalg.py index 73e63c19b646210..1e3d8d94374e4c4 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1,10 +1,14 @@ import torch import unittest +import itertools +from math import inf, nan, isnan from torch.testing._internal.common_utils import \ (TestCase, run_tests, TEST_NUMPY) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack) +from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args +from torch.autograd import gradcheck if TEST_NUMPY: import numpy as np @@ -54,6 +58,481 @@ def test_det(self, device, dtype): with self.assertRaises(IndexError): op(t) + # This test confirms that torch.linalg.norm's dtype argument works + # as expected, according to the function's documentation + def test_norm_dtype(self, device): + def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype): + msg = ( + f'input_size={input_size}, ord={ord}, keepdim={keepdim}, ' + f'from_dtype={from_dtype}, to_dtype={to_dtype}') + input = torch.randn(*input_size, dtype=from_dtype, device=device) + result = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=from_dtype) + self.assertEqual(result.dtype, from_dtype, msg=msg) + result_converted = torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype) + self.assertEqual(result_converted.dtype, to_dtype, msg=msg) + self.assertEqual(result.to(compare_dtype), result_converted.to(compare_dtype), msg=msg) + + result_out_converted = torch.empty_like(result_converted) + torch.linalg.norm(input, ord, keepdim=keepdim, dtype=to_dtype, out=result_out_converted) + self.assertEqual(result_out_converted.dtype, to_dtype, msg=msg) + self.assertEqual(result_converted, result_out_converted, msg=msg) + + ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] + ord_matrix = [1, -1, 2, -2, inf, -inf, None] + S = 10 + test_cases = [ + ((S, ), ord_vector), + ((S, S), ord_matrix), + ] + for keepdim in [True, False]: + for input_size, ord_settings in test_cases: + for ord in ord_settings: + # float to double + run_test_case(input_size, ord, keepdim, torch.float, torch.double, torch.float) + # double to float + run_test_case(input_size, ord, keepdim, torch.double, torch.double, torch.float) + + # Make sure that setting dtype != out.dtype raises an error + dtype_pairs = [ + (torch.float, torch.double), + (torch.double, torch.float), + ] + for keepdim in [True, False]: + for input_size, ord_settings in test_cases: + for ord in ord_settings: + for dtype, out_dtype in dtype_pairs: + input = torch.rand(*input_size) + result = torch.Tensor().to(out_dtype) + with self.assertRaisesRegex(RuntimeError, r'provided dtype must match dtype of result'): + torch.linalg.norm(input, ord=ord, keepdim=keepdim, dtype=dtype, out=result) + + # TODO: Once dtype arg is supported in nuclear and frobenius norms, remove the following test + # and add 'nuc' and 'fro' to ord_matrix above + for ord in ['nuc', 'fro']: + input = torch.randn(10, 10, device=device) + with self.assertRaisesRegex(RuntimeError, f"ord=\'{ord}\' does not yet support the dtype argument"): + torch.linalg.norm(input, ord, dtype=torch.float) + + # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that + # their vector norm results match + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float, torch.double) + def test_norm_vector(self, device, dtype): + def run_test_case(input, p, dim, keepdim): + result = torch.linalg.norm(input, ord, dim, keepdim) + input_numpy = input.cpu().numpy() + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + self.assertEqual(result, result_numpy, msg=msg) + + result_out = torch.empty_like(result) + torch.linalg.norm(input, ord, dim, keepdim, out=result_out) + self.assertEqual(result, result_out, msg=msg) + + ord_vector = [0, 1, -1, 2, -2, 3, -3, 4.5, -4.5, inf, -inf, None] + S = 10 + test_cases = [ + # input size, p settings, dim + ((S, ), ord_vector, None), + ((S, ), ord_vector, (0, )), + ((S, S, S), ord_vector, (0, )), + ((S, S, S), ord_vector, (1, )), + ((S, S, S), ord_vector, (2, )), + ((S, S, S), ord_vector, (-1, )), + ((S, S, S), ord_vector, (-2, )), + ] + L = 1_000_000 + if dtype == torch.double: + test_cases.append(((L, ), ord_vector, None)) + for keepdim in [True, False]: + for input_size, ord_settings, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_settings: + run_test_case(input, ord, dim, keepdim) + + # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that + # their matrix norm results match + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float, torch.double) + def test_norm_matrix(self, device, dtype): + def run_test_case(input, p, dim, keepdim): + result = torch.linalg.norm(input, ord, dim, keepdim) + input_numpy = input.cpu().numpy() + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + self.assertEqual(result, result_numpy, msg=msg) + + result_out = torch.empty_like(result) + torch.linalg.norm(input, ord, dim, keepdim, out=result_out) + self.assertEqual(result, result_out, msg=msg) + + ord_matrix = [1, -1, 2, -2, inf, -inf, 'nuc', 'fro', None] + S = 10 + test_cases = [ + # input size, p settings, dim + ((S, S), ord_matrix, None), + ((S, S), ord_matrix, (0, 1)), + ((S, S), ord_matrix, (1, 0)), + ((S, S, S, S), ord_matrix, (2, 0)), + ((S, S, S, S), ord_matrix, (-1, -2)), + ((S, S, S, S), ord_matrix, (-1, -3)), + ((S, S, S, S), ord_matrix, (-3, 2)), + ] + L = 1_000 + if dtype == torch.double: + test_cases.append(((L, L), ord_matrix, None)) + for keepdim in [True, False]: + for input_size, ord_settings, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_settings: + run_test_case(input, ord, dim, keepdim) + + # Test autograd and jit functionality for linalg functions. + # TODO: Once support for linalg functions is added to method_tests in common_methods_invocations.py, + # the `test_cases` entries below should be moved there. These entries are in a similar format, + # so they should work with minimal changes. + @dtypes(torch.float, torch.double) + def test_autograd_and_jit(self, device, dtype): + torch.manual_seed(0) + S = 10 + NO_ARGS = None # NOTE: refer to common_methods_invocations.py if you need this feature + test_cases = [ + # NOTE: Not all the features from common_methods_invocations.py are functional here, since this + # is only a temporary solution. + # ( + # method name, + # input size/constructing fn, + # args (tuple represents shape of a tensor arg), + # test variant name (will be used at test name suffix), // optional + # (should_check_autodiff[bool], nonfusible_nodes, fusible_nodes) for autodiff, // optional + # indices for possible dim arg, // optional + # fn mapping output to part that should be gradcheck'ed, // optional + # kwargs // optional + # ) + ('norm', (S,), (), 'default_1d'), + ('norm', (S, S), (), 'default_2d'), + ('norm', (S, S, S), (), 'default_3d'), + ('norm', (S,), (inf,), 'vector_inf'), + ('norm', (S,), (3.5,), 'vector_3_5'), + ('norm', (S,), (2,), 'vector_2'), + ('norm', (S,), (1,), 'vector_1'), + ('norm', (S,), (0,), 'vector_0'), + ('norm', (S,), (-inf,), 'vector_neg_inf'), + ('norm', (S,), (-3.5,), 'vector_neg_3_5'), + ('norm', (S,), (2,), 'vector_neg_2'), + ('norm', (S,), (1,), 'vector_neg_1'), + ('norm', (S, S), (inf,), 'matrix_inf'), + ('norm', (S, S), (2,), 'matrix_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('norm', (S, S), (1,), 'matrix_1'), + ('norm', (S, S), (-inf,), 'matrix_neg_inf'), + ('norm', (S, S), (-2,), 'matrix_neg_2', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('norm', (S, S), (-1,), 'matrix_neg_1'), + ('norm', (S, S), ('fro',), 'fro'), + ('norm', (S, S), ('fro', [0, 1]), 'fro_dim'), + ('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ('norm', (S, S), ('nuc', [0, 1]), 'nuc_dim', (), NO_ARGS, [skipCPUIfNoLapack, skipCUDAIfNoMagma]), + ] + for test_case in test_cases: + func_name = test_case[0] + func = getattr(torch.linalg, func_name) + input_size = test_case[1] + args = list(test_case[2]) + test_case_name = test_case[3] if len(test_case) >= 4 else None + mapping_funcs = list(test_case[6]) if len(test_case) >= 7 else None + + # Skip a test if a decorator tells us to + if mapping_funcs is not None: + def decorated_func(self, device, dtype): + pass + for mapping_func in mapping_funcs: + decorated_func = mapping_func(decorated_func) + try: + decorated_func(self, device, dtype) + except unittest.SkipTest: + continue + + msg = f'function name: {func_name}, case name: {test_case_name}' + + # Test JIT + input = torch.randn(*input_size, dtype=dtype, device=device) + input_script = input.clone().detach() + script_method, tensors = gen_script_fn_and_args("linalg.norm", "functional", input_script, *args) + self.assertEqual( + func(input, *args), + script_method(input_script), + msg=msg) + + # Test autograd + # gradcheck is only designed to work with torch.double inputs + if dtype == torch.double: + input = torch.randn(*input_size, dtype=dtype, device=device, requires_grad=True) + + def run_func(input): + return func(input, *args) + self.assertTrue(gradcheck(run_func, input), msg=msg) + + # This test calls torch.linalg.norm and numpy.linalg.norm with illegal arguments + # to ensure that they both throw errors + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.float, torch.double) + def test_norm_errors(self, device, dtype): + def run_error_test_case(input, ord, dim, keepdim, error_type, error_regex): + test_case_info = ( + f'test case input.size()={input.size()}, ord={ord}, dim={dim}, ' + f'keepdim={keepdim}, dtype={dtype}') + + with self.assertRaisesRegex(error_type, error_regex, msg=test_case_info): + torch.linalg.norm(input, ord, dim, keepdim) + + input_numpy = input.cpu().numpy() + + msg = f'numpy does not raise error but pytorch does, for case "{test_case_info}"' + with self.assertRaises(Exception, msg=test_case_info): + np.linalg.norm(input_numpy, ord, dim, keepdim) + + S = 10 + error_test_cases = [ + # input size, p settings, dim, error type, error regex + ((S, ), ['fro'], None, RuntimeError, r'order "fro" can only be used if either len\(dim\) == 2'), + ((S, ), ['nuc'], None, RuntimeError, r'order "nuc" can only be used if either len\(dim\) == 2'), + ((S, S), [3.5], None, RuntimeError, r'Order 3.5 not supported for matrix norm'), + ((S, S), [0], None, RuntimeError, r'Order 0 not supported for matrix norm'), + ((S, S), ['nuc'], (0, ), RuntimeError, r'order "nuc" can only be used if either len\(dim\) == 2'), + ((S, S), ['fro'], (0, ), RuntimeError, r'order "fro" can only be used if either len\(dim\) == 2'), + ((S, S), ['nuc'], (0, 0), RuntimeError, r'duplicate or invalid dimensions'), + ((S, S), ['fro', 0], (0, 0), RuntimeError, r'Expected dims to be different'), + ((S, S), ['fro', 'nuc', 0], (0, 4), IndexError, r'Dimension out of range'), + ((S, ), [0], (4, ), IndexError, r'Dimension out of range'), + ((S, ), [None], (0, 0), RuntimeError, r'Expected dims to be different, got this instead'), + ((S, S, S), [1], (0, 1, 2), RuntimeError, r"'dim' must specify 1 or 2 dimensions"), + ((S, S, S), [1], None, RuntimeError, r"'dim' must specify 1 or 2 dimensions"), + ((S, S), ['garbage'], (0, 1), RuntimeError, r'Invalid norm order: garbage'), + ] + for keepdim in [True, False]: + for input_size, ord_settings, dim, error_type, error_regex in error_test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_settings: + run_error_test_case(input, ord, dim, keepdim, error_type, error_regex) + + # Test complex number inputs for linalg.norm. Some cases are not supported yet, so + # this test also verifies that those cases raise an error. + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(torch.cfloat, torch.cdouble) + def test_norm_complex(self, device, dtype): + def gen_error_message(input_size, ord, keepdim, dim=None): + return "complex norm failed for input size %s, ord=%s, keepdim=%s, dim=%s" % ( + input_size, ord, keepdim, dim) + + if self.device_type == 'cpu': + supported_vector_ords = [0, 1, 3, inf, -1, -2, -3, -inf] + supported_matrix_ords = ['nuc', 1, 2, inf, -1, -2, -inf] + unsupported_vector_ords = [ + (2, r'norm with p=2 not supported for complex tensors'), + (None, r'norm with p=2 not supported for complex tensors'), + ] + unsupported_matrix_ords = [ + ('fro', r'frobenius norm not supported for complex tensors'), + (None, r'norm with p=2 not supported for complex tensors'), + ] + + elif self.device_type == 'cuda': + supported_vector_ords = [inf, -inf] + supported_matrix_ords = [1, inf, -1, -inf] + unsupported_vector_ords = [ + (0, r'norm_cuda" not implemented for \'Complex'), + (1, r'norm_cuda" not implemented for \'Complex'), + (2, r'norm with p=2 not supported for complex tensors'), + (-1, r'norm_cuda" not implemented for \'Complex'), + (-2, r'norm_cuda" not implemented for \'Complex'), + (None, r'norm with p=2 not supported for complex tensors'), + ] + unsupported_matrix_ords = [ + (None, r'norm with p=2 not supported for complex tensors'), + ('fro', r'frobenius norm not supported for complex tensors'), + (2, r'"svd_cuda" not implemented for \'Complex'), + (-2, r'"svd_cuda" not implemented for \'Complex'), + ('nuc', r'"svd_cuda" not implemented for \'Complex'), + ] + + # Test supported ords + for keepdim in [False, True]: + # vector norm + x = torch.randn(25, device=device, dtype=dtype) + xn = x.cpu().numpy() + for ord in supported_vector_ords: + res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, ord, keepdims=keepdim) + msg = gen_error_message(x.size(), ord, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # matrix norm + x = torch.randn(25, 25, device=device, dtype=dtype) + xn = x.cpu().numpy() + for ord in supported_matrix_ords: + # TODO: Need to fix abort when nuclear norm is given cdouble input: + # "double free or corruption (!prev) Aborted (core dumped)" + if ord == 'nuc' and dtype == torch.cdouble: + continue + res = torch.linalg.norm(x, ord, keepdim=keepdim).cpu() + expected = np.linalg.norm(xn, ord, keepdims=keepdim) + msg = gen_error_message(x.size(), ord, keepdim) + self.assertEqual(res.shape, expected.shape, msg=msg) + self.assertEqual(res, expected, msg=msg) + + # Test unsupported ords + # vector norm + x = torch.randn(25, device=device, dtype=dtype) + for ord, error_msg in unsupported_vector_ords: + with self.assertRaisesRegex(RuntimeError, error_msg): + torch.linalg.norm(x, ord) + + # matrix norm + x = torch.randn(25, 25, device=device, dtype=dtype) + for ord, error_msg in unsupported_matrix_ords: + with self.assertRaisesRegex(RuntimeError, error_msg): + torch.linalg.norm(x, ord) + + # Make sure that linalg.norm raises an error if dim is an integer + # TODO: When integer dims are supported in norm, remove this test + def test_norm_dim_int_error(self, device): + input = torch.randn(10, device=device) + with self.assertRaisesRegex(TypeError, r'linalg_norm\(\) received an invalid combination of arguments'): + torch.linalg.norm(input, dim=0) + + # Test that linal.norm gives the same result as numpy when inputs + # contain extreme values (inf, -inf, nan) + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + def test_norm_extreme_values(self, device): + vector_ords = [0, 1, 2, 3, inf, -1, -2, -3, -inf] + matrix_ords = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf] + vectors = [] + matrices = [] + for pair in itertools.product([inf, -inf, 0.0, nan, 1.0], repeat=2): + vectors.append(list(pair)) + matrices.append([[pair[0], pair[1]]]) + matrices.append([[pair[0]], [pair[1]]]) + for vector in vectors: + x = torch.tensor(vector).to(device) + x_n = x.cpu().numpy() + for ord in vector_ords: + msg = f'ord={ord}, vector={vector}' + result = torch.linalg.norm(x, ord=ord) + result_n = np.linalg.norm(x_n, ord=ord) + self.assertEqual(result, result_n, msg=msg) + + # TODO: Remove this function once the broken cases are fixed + def is_broken_matrix_norm_case(ord, x): + if self.device_type == 'cuda': + if x.size() == torch.Size([1, 2]): + if ord in ['nuc', 2, -2] and isnan(x[0][0]) and x[0][1] == 1: + # These cases are broken because of an issue with svd + # https://github.com/pytorch/pytorch/issues/43567 + return True + return False + + for matrix in matrices: + x = torch.tensor(matrix).to(device) + x_n = x.cpu().numpy() + for ord in matrix_ords: + msg = f'ord={ord}, matrix={matrix}' + result = torch.linalg.norm(x, ord=ord) + result_n = np.linalg.norm(x_n, ord=ord) + + if is_broken_matrix_norm_case(ord, x): + self.assertNotEqual(result, result_n, msg=msg) + else: + self.assertEqual(result, result_n, msg=msg) + + # Test degenerate shape results match numpy for linalg.norm vector norms + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_norm_vector_degenerate_shapes(self, device, dtype): + def run_test_case(input, ord, dim, keepdim, should_error): + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + input_numpy = input.cpu().numpy() + if should_error: + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + else: + if dtype in [torch.cfloat, torch.cdouble] and ord in [2, None]: + # TODO: Once these ord values have support for complex numbers, + # remove this error test case + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + return + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + result = torch.linalg.norm(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + + ord_vector = [0, 0.5, 1, 2, 3, inf, -0.5, -1, -2, -3, -inf, None] + S = 10 + test_cases = [ + # input size, p settings that cause error, dim + ((0, ), [inf, -inf], None), + ((0, S), [inf, -inf], (0,)), + ((0, S), [], (1,)), + ((S, 0), [], (0,)), + ((S, 0), [inf, -inf], (1,)), + ] + for keepdim in [True, False]: + for input_size, error_ords, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_vector: + run_test_case(input, ord, dim, keepdim, ord in error_ords) + + # Test degenerate shape results match numpy for linalg.norm matrix norms + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @unittest.skipIf(not TEST_NUMPY, "Numpy not found") + @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + def test_norm_matrix_degenerate_shapes(self, device, dtype): + def run_test_case(input, ord, dim, keepdim, should_error): + if dtype in [torch.cfloat, torch.cdouble] and ord in ['fro', None]: + # TODO: Once these ord values have support for complex numbers, + # remove this error test case + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + return + msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}' + input_numpy = input.cpu().numpy() + if should_error: + with self.assertRaises(ValueError): + np.linalg.norm(input_numpy, ord, dim, keepdim) + with self.assertRaises(RuntimeError): + torch.linalg.norm(input, ord, dim, keepdim) + else: + result_numpy = np.linalg.norm(input_numpy, ord, dim, keepdim) + result = torch.linalg.norm(input, ord, dim, keepdim) + self.assertEqual(result, result_numpy, msg=msg) + + ord_matrix = ['fro', 'nuc', 1, 2, inf, -1, -2, -inf, None] + S = 10 + test_cases = [ + # input size, p settings that cause error, dim + ((0, 0), [1, 2, inf, -1, -2, -inf], None), + ((0, S), [2, inf, -2, -inf], None), + ((S, 0), [1, 2, -1, -2], None), + ((S, S, 0), [], (0, 1)), + ((1, S, 0), [], (0, 1)), + ((0, 0, S), [1, 2, inf, -1, -2, -inf], (0, 1)), + ((0, 0, S), [1, 2, inf, -1, -2, -inf], (1, 0)), + ] + for keepdim in [True, False]: + for input_size, error_ords, dim in test_cases: + input = torch.randn(*input_size, dtype=dtype, device=device) + for ord in ord_matrix: + run_test_case(input, ord, dim, keepdim, ord in error_ords) instantiate_device_type_tests(TestLinalg, globals()) diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index b162e94d5f9cbfc..dd10ce6017552b0 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -273,6 +273,33 @@ def test_max_pool3d(self): max_pool3d(x), max_pool3d(x.to_mkldnn()).to_dense()) + def test_max_pool_unsupported(self): + # OneDNN not support dilation max_pooling, will be avilabled in v2.0. + N = torch.randint(3, 10, (1,)).item() + C = torch.randint(3, 10, (1,)).item() + + # 2d dilation case + x = torch.randn(N, C, 7, 7, dtype=torch.float32).to_mkldnn() + max_pool2d = torch.nn.MaxPool2d( + kernel_size=3, + stride=3, + padding=1, + dilation=2) + self.assertRaisesRegex(RuntimeError, + 'mkldnn_max_pool2d does not support dilation case', + lambda: max_pool2d(x)) + + # 3d dilation case + x = torch.randn(N, C, 7, 7, 7, dtype=torch.float32).to_mkldnn() + max_pool3d = torch.nn.MaxPool3d( + kernel_size=3, + stride=3, + padding=1, + dilation=2) + self.assertRaisesRegex(RuntimeError, + 'mkldnn_max_pool3d does not support dilation case', + lambda: max_pool3d(x)) + def test_avg_pool2d(self): N = torch.randint(3, 10, (1,)).item() C = torch.randint(3, 10, (1,)).item() diff --git a/test/test_nn.py b/test/test_nn.py index 49ca68a6d6d7533..376734415726adb 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -11913,6 +11913,45 @@ def __init__(self): for p, pe in zip(test_model.parameters(), ref_model.parameters()): self.assertEqual(p.grad.to(devices[0]), pe.grad) + def test_elu_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.elu(x, inplace=True) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.elu_(x) + + def test_hardswish_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.hardswish(x, inplace=True) + + def test_silu_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.silu(x, inplace=True) + + def test_softplus_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.softplus(x, out=x) + + def test_softshrink_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.softshrink(x, out=x) + + def test_leaky_relu_inplace_overlap(self, device): + x = torch.randn((1, 6), device=device).expand((6, 6)) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.leaky_relu(x, inplace=True) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + F.leaky_relu_(x) + + def test_threshold_inplace_overlap(self, device): + # Inplace threshold is okay, because it is idempotent + x = torch.randn((1, 6), device=device).expand((6, 6)) + F.threshold(x, 0.5, 0.5, inplace=True) + F.threshold_(x, 0.5, 0.5) class TestModuleGlobalHooks(TestCase): diff --git a/test/test_op_aliases.py b/test/test_op_aliases.py index 2dfd739889acb90..e00c0f7e7590161 100644 --- a/test/test_op_aliases.py +++ b/test/test_op_aliases.py @@ -72,6 +72,22 @@ def __init__(self, lambda d: 10 * torch.randn(20, device=d)), AliasInfo('negative_', torch.Tensor.negative_, 'neg_', torch.Tensor.neg_, lambda d: 10 * torch.randn(20, device=d)), + AliasInfo('arcsinh', torch.arcsinh, 'asinh', torch.asinh, + lambda d: torch.randn(20, device=d)), + AliasInfo('arcsinh_', torch.Tensor.arcsinh_, 'asinh_', torch.Tensor.asinh_, + lambda d: torch.randn(20, device=d)), + AliasInfo('arctanh', torch.arctanh, 'atanh', torch.atanh, + lambda d: torch.clamp(torch.randn(20, device=d), -1, 1)), + AliasInfo('arctanh_', torch.Tensor.arctanh_, 'atanh_', torch.Tensor.atanh_, + lambda d: torch.clamp(torch.randn(20, device=d), -1, 1)), + AliasInfo('subtract', torch.subtract, 'sub', torch.sub, + lambda d: torch.randn(20, device=d), + get_args=lambda d: (torch.randn(20, device=d),), + decorators=(onlyCPU,)), + AliasInfo('subtract_', torch.Tensor.subtract_, 'sub_', torch.Tensor.sub_, + lambda d: torch.randn(20, device=d), + get_args=lambda d: (torch.randn(20, device=d),), + decorators=(onlyCPU,)), ) # Placeholder test class for validating that aliases are correctly diff --git a/test/test_pruning_op.py b/test/test_pruning_op.py new file mode 100644 index 000000000000000..17b592b34717e5d --- /dev/null +++ b/test/test_pruning_op.py @@ -0,0 +1,78 @@ +from __future__ import absolute_import, division, print_function, unicode_literals + +import hypothesis.strategies as st +from hypothesis import given +import numpy as np +import torch +from torch.testing._internal.common_utils import TestCase +import torch.testing._internal.hypothesis_utils as hu +hu.assert_deadline_disabled() + + +class PruningOpTest(TestCase): + + # Generate rowwise mask vector based on indicator and threshold value. + # indicator is a vector that contains one value per weight row and it + # represents the importance of a row. + # We mask a row if its indicator value is less than the threshold. + def _generate_rowwise_mask(self, embedding_rows): + indicator = torch.from_numpy((np.random.random_sample(embedding_rows)).astype(np.float32)) + threshold = np.random.random_sample() + mask = torch.BoolTensor([True if val >= threshold else False for val in indicator]) + return mask + + def _test_rowwise_prune_op(self, embedding_rows, embedding_dims, indices_type, weights_dtype): + embedding_weights = None + if weights_dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: + embedding_weights = torch.randint(0, 100, (embedding_rows, embedding_dims), dtype=weights_dtype) + else: + embedding_weights = torch.rand((embedding_rows, embedding_dims), dtype=weights_dtype) + mask = self._generate_rowwise_mask(embedding_rows) + + def get_pt_result(embedding_weights, mask, indices_type): + return torch.rowwise_prune(embedding_weights, mask, indices_type) + + # Reference implementation. + def get_reference_result(embedding_weights, mask, indices_type): + num_embeddings = mask.size()[0] + compressed_idx_out = torch.zeros(num_embeddings, dtype=indices_type) + pruned_weights_out = embedding_weights[mask[:]] + idx = 0 + for i in range(mask.size()[0]): + if mask[i]: + compressed_idx_out[i] = idx + idx = idx + 1 + else: + compressed_idx_out[i] = -1 + return (pruned_weights_out, compressed_idx_out) + + pt_pruned_weights, pt_compressed_indices_map = get_pt_result( + embedding_weights, mask, indices_type) + ref_pruned_weights, ref_compressed_indices_map = get_reference_result( + embedding_weights, mask, indices_type) + + torch.testing.assert_allclose(pt_pruned_weights, ref_pruned_weights) + self.assertEqual(pt_compressed_indices_map, ref_compressed_indices_map) + self.assertEqual(pt_compressed_indices_map.dtype, indices_type) + + + @given( + embedding_rows=st.integers(1, 100), + embedding_dims=st.integers(1, 100), + weights_dtype=st.sampled_from([torch.float64, torch.float32, + torch.float16, torch.int8, + torch.int16, torch.int32, torch.int64]) + ) + def test_rowwise_prune_op_32bit_indices(self, embedding_rows, embedding_dims, weights_dtype): + self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int, weights_dtype) + + + @given( + embedding_rows=st.integers(1, 100), + embedding_dims=st.integers(1, 100), + weights_dtype=st.sampled_from([torch.float64, torch.float32, + torch.float16, torch.int8, + torch.int16, torch.int32, torch.int64]) + ) + def test_rowwise_prune_op_64bit_indices(self, embedding_rows, embedding_dims, weights_dtype): + self._test_rowwise_prune_op(embedding_rows, embedding_dims, torch.int64, weights_dtype) diff --git a/test/test_quantization.py b/test/test_quantization.py index a0db0fd50cb8246..0aeba1f08586d4a 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -63,6 +63,7 @@ # 3. GraphModule based graph mode quantization from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401 from quantization.test_quantize_fx import TestQuantizeFxOps # noqa: F401 +from quantization.test_quantize_fx import TestQuantizeFxModels # noqa: F401 # Tooling: numric_suite from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401 diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 27d8213840a2b00..4d51f4095095866 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -2,6 +2,8 @@ from torch import nn import numpy as np +from torch.testing._internal.common_utils import TestCase, run_tests + class StaticRuntime: def __init__(self, scripted): @@ -90,50 +92,56 @@ def trivial_graph(a, b, c): s = torch.tensor([[3, 3], [3, 3]]) return a + b * c + s +class TestStaticRuntime(TestCase): + def test_multihead_attention_layer(self): + HID_DIM = 256 + QUERY_LEN = 8 + BATCH_SIZE = 128 + LAYERS = 3 + HEADS = 8 + DROPOUT = 0.1 + device = torch.device("cpu") + attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) + src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) + src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) + + attention.eval() + attention = torch.jit.script(attention) + attention.eval() + o_ref = attention(src, src, src, src_mask) + + attention_a = StaticRuntime(attention) + o_test = attention_a(src, src, src, src_mask) + for a, b in zip(o_ref, o_test): + torch.testing.assert_allclose(a, b) + + def test_mlp(self): + # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh + ln_bot = [512, 512, 64] + sigmoid_bot = -1 + ln_top = [100, 1024, 1024, 1024, 1] + sigmoid_top = 3 + bot_l = create_mlp(ln_bot, sigmoid_bot) + bot_l_acc = StaticRuntime(bot_l) + top_l = create_mlp(ln_top, sigmoid_top) + top_l_acc = StaticRuntime(top_l) + bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) + top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) + ref_bot = bot_l(bot_inp) + acc_bot = bot_l_acc(bot_inp)[0] + torch.testing.assert_allclose(acc_bot, ref_bot) + ref_top = top_l(top_inp) + acc_top = top_l_acc(top_inp)[0] + torch.testing.assert_allclose(acc_top, ref_top) + + + # def test_trivial_graph(self): + # s = torch.full((2, 2), 2) + # tg = torch.jit.script(trivial_graph) + # o_ref = tg(s, s, s) + # tg_a = StaticRuntime(tg) + # o_test = tg_a(s, s, s)[0] + # torch.testing.assert_allclose(o_ref, o_test) if __name__ == "__main__": - HID_DIM = 256 - QUERY_LEN = 8 - BATCH_SIZE = 128 - LAYERS = 3 - HEADS = 8 - DROPOUT = 0.1 - device = torch.device("cpu") - attention = MultiHeadAttentionLayer(HID_DIM, HEADS, DROPOUT, device).to(device) - src = torch.randn(BATCH_SIZE, QUERY_LEN, HID_DIM).to(device) - src_mask = (src > 0)[:, :, 0].unsqueeze(1).unsqueeze(2).to(device) - - attention.eval() - attention = torch.jit.script(attention) - attention.eval() - o_ref = attention(src, src, src, src_mask) - - attention_a = StaticRuntime(attention) - o_test = attention_a(src, src, src, src_mask) - for a, b in zip(o_ref, o_test): - torch.testing.assert_allclose(a, b) - - s = torch.full((2, 2), 2) - tg = torch.jit.script(trivial_graph) - o_ref = tg(s, s, s) - tg_a = StaticRuntime(tg) - o_test = tg_a(s, s, s)[0] - torch.testing.assert_allclose(o_ref, o_test) - - # Arguments taken from benchmark script, ./bench/dlrm_s_benchmark.sh - ln_bot = [512, 512, 64] - sigmoid_bot = -1 - ln_top = [100, 1024, 1024, 1024, 1] - sigmoid_top = 3 - bot_l = create_mlp(ln_bot, sigmoid_bot) - bot_l_acc = StaticRuntime(bot_l) - top_l = create_mlp(ln_top, sigmoid_top) - top_l_acc = StaticRuntime(top_l) - bot_inp = torch.randn(2048, 512) # torch.Size([2048, 512]) - top_inp = torch.randn(2048, 100) # torch.Size([2048, 100]) - ref_bot = bot_l(bot_inp) - acc_bot = bot_l_acc(bot_inp)[0] - torch.testing.assert_allclose(acc_bot, ref_bot) - ref_top = top_l(top_inp) - acc_top = top_l_acc(top_inp)[0] - torch.testing.assert_allclose(acc_top, ref_top) + run_tests() diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py index 046b8a26ca283f6..9945f604799b954 100644 --- a/test/test_tensorexpr.py +++ b/test/test_tensorexpr.py @@ -1281,5 +1281,25 @@ def test(x): scripted(x) assert torch.equal(scripted(x), test(x)) + def test_simple_add(self): + val = torch._C._jit_get_te_generate_block_code() + torch._C._jit_set_te_generate_block_code(True) + fall_bk = torch._C._jit_texpr_fallback_allowed() + torch._C._jit_texpr_set_fallback_allowed(True) + + def simple(a, b): + return torch.add(a, b) + + a = torch.ones(256, 256) + b = torch.ones(256, 256) + traced = torch.jit.trace(simple, + (torch.ones(256, 256), torch.ones(256, 256))) + f = traced(a, b) + f_test = np.full((256, 256), 2, dtype=float) + np.testing.assert_allclose(f.numpy(), f_test) + torch._C._jit_set_te_generate_block_code(val) + torch._C._jit_texpr_set_fallback_allowed(fall_bk) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_torch.py b/test/test_torch.py index 1e6684296638bfb..ecc4650b28ddfb4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2105,10 +2105,6 @@ def test_masked_fill(self): val = random.random() dst2 = dst.clone() - if dt == torch.half: - self.assertRaises(RuntimeError, lambda: dst.masked_fill_(mask, val)) - continue - dst.masked_fill_(mask, val) for i in range(num_dest): if mask[i]: @@ -2122,7 +2118,7 @@ def test_masked_fill(self): dst2.masked_fill_((dst2 > 0).to(dtype), val) self.assertEqual(dst, dst2, atol=0, rtol=0) - self.assertEqual(len(w), 34) + self.assertEqual(len(w), 36) warn = 'masked_fill_ received a mask with dtype torch.uint8,' for wi in w: @@ -9801,6 +9797,33 @@ def test_diagflat(self, device): expected = torch.diag(x.contiguous().view(-1)) self.assertEqual(result, expected) + # Ensure that nuclear_norm's out variant gives the same result as the non-out + @onlyOnCPUAndCUDA + @skipCUDAIfNoMagma + @skipCPUIfNoLapack + @dtypes(torch.float32, torch.float64) + def test_nuclear_norm_out(self, device, dtype): + test_cases = [ + # input size, dim + ((25, 25), None), + ((25, 25), (0, 1)), + ((25, 25), (1, 0)), + ((25, 25, 25), (2, 0)), + ((25, 25, 25), (0, 1)), + ] + for keepdim in [False, True]: + for input_size, dim in test_cases: + msg = f'input_size: {input_size}, dim: {dim}, keepdim: {keepdim}' + x = torch.randn(*input_size, device=device, dtype=dtype) + result_out = torch.empty(0, device=device, dtype=dtype) + if dim is None: + result = torch.nuclear_norm(x, keepdim=keepdim) + torch.nuclear_norm(x, keepdim=keepdim, out=result_out) + else: + result = torch.nuclear_norm(x, keepdim=keepdim, dim=dim) + torch.nuclear_norm(x, keepdim=keepdim, dim=dim, out=result_out) + self.assertEqual(result, result_out, msg=msg) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @unittest.skipIf(not TEST_NUMPY, "Numpy not found") @@ -9986,13 +10009,14 @@ def check_single_nuclear_norm(x, axes): @skipCUDAIfNoMagma def test_nuclear_norm_exceptions(self, device): for lst in [], [1], [1, 2]: - for axes in (), (0,), (0, 1): - x = torch.tensor(lst, dtype=torch.double, device=device) + x = torch.tensor(lst, dtype=torch.double, device=device) + for axes in (), (0,): self.assertRaises(RuntimeError, torch.norm, x, "nuc", axes) + self.assertRaises(IndexError, torch.norm, x, "nuc", (0, 1)) x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.double, device=device) self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 0)) - self.assertRaisesRegex(RuntimeError, "duplicate or invalid", torch.norm, x, "nuc", (0, 2)) + self.assertRaisesRegex(IndexError, "Dimension out of range", torch.norm, x, "nuc", (0, 2)) def test_embedding_scalar_weight_error(self, device): indices = torch.rand(2, 2, device=device).long() @@ -10862,10 +10886,6 @@ def isBinary(t): torch.bernoulli(torch.rand_like(p), out=p) self.assertTrue(isBinary(p)) - p = torch.rand(5, dtype=dtype, device=device).expand(5, 5) - torch.bernoulli(torch.rand_like(p), out=p) - self.assertTrue(isBinary(p)) - # RngUniform not implemented for Integral type in XLA test @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False))) @dtypesIfCPU(*(torch.testing.get_all_dtypes(include_half=False, include_bfloat16=False, include_complex=False))) @@ -12239,10 +12259,6 @@ def test_logical(self, device): x = torch.tensor([1, 2, 3, 4], device=device, dtype=dt) b = torch.tensor([2], device=device, dtype=dt) - if dt == torch.half and device == 'cpu': - self.assertRaises(RuntimeError, lambda: x.lt(2)) - continue - if dt == torch.bool: # torch.bool is a special case and is being tested later # in this test @@ -12539,10 +12555,6 @@ def test_masked_select(self, device, dtype): src = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=dtype, device=device) mask = torch.randint(2, (num_src,), device=device, dtype=maskType) - if dtype == torch.half and torch.device(device).type == 'cpu': - self.assertRaises(RuntimeError, lambda: src.masked_select(mask)) - continue - with warnings.catch_warnings(record=True) as w: dst = src.masked_select(mask) if maskType is torch.uint8: @@ -13173,6 +13185,22 @@ def test_atanh_domain_float(self, device, dtype): self.assertEqual(torch.isinf(torch.atanh(sample)), inf_mask) self.assertEqual(torch.isinf(sample.atanh()), inf_mask) + def test_nullary_op_mem_overlap(self, device): + ops = ( + ("random_", ()), + ("uniform_", ()), + ("cauchy_", ()), + ("log_normal_", ()), + ("exponential_", ()), + ("geometric_", (0.5,)), + ("normal_", ()), + ) + + x = torch.rand((1, 3)).expand((3, 3)) + for op, args in ops: + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + getattr(x, op)(*args) + # TODO: run on non-native device types @dtypes(torch.double) def test_unary_out_op_mem_overlap(self, device, dtype): @@ -13277,7 +13305,33 @@ def test_binary_op_mem_overlap(self, device, dtype): ("div", True, True, 'cpu'), ("div", True, True, 'cuda'), ("pow", True, True, 'cpu'), - ("pow", True, True, 'cuda') + ("pow", True, True, 'cuda'), + ("fmod", True, True, 'cpu'), + ("fmod", True, True, 'cuda'), + ("atan2", True, True, 'cpu'), + ("atan2", True, True, 'cuda'), + ("hypot", True, True, 'cpu'), + ("hypot", True, True, 'cuda'), + ("nextafter", True, True, 'cpu'), + ("nextafter", True, True, 'cuda'), + ("le", True, True, 'cpu'), + ("le", True, True, 'cuda'), + ("lt", True, True, 'cpu'), + ("lt", True, True, 'cuda'), + ("ge", True, True, 'cpu'), + ("ge", True, True, 'cuda'), + ("gt", True, True, 'cpu'), + ("gt", True, True, 'cuda'), + ("eq", True, True, 'cpu'), + ("eq", True, True, 'cuda'), + ("ne", True, True, 'cpu'), + ("ne", True, True, 'cuda'), + ("logical_and", True, True, 'cpu'), + ("logical_and", True, True, 'cuda'), + ("logical_or", True, True, 'cpu'), + ("logical_or", True, True, 'cuda'), + ("logical_xor", True, True, 'cpu'), + ("logical_xor", True, True, 'cuda'), ] for (fn, has_input_output_mem_overlap_check, @@ -13336,6 +13390,21 @@ def test_pow_scalar_overloads_mem_overlap(self, device, dtype): self.unary_check_input_output_mem_overlap( doubles, sz, lambda input, out: torch.pow(42, input, out=out)) + def test_index_add_mem_overlap(self, device): + x = torch.rand((1,), device=device).expand((6,)) + y = torch.rand((6,), device=device) + ind = torch.tensor([0, 2, 3], device=device) + value = torch.rand((3,), device=device) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + x.index_add_(0, ind, value) + + def test_shift_mem_overlap(self, device): + x = torch.rand(3, device=device) + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + x[:-1] <<= x[1:] + with self.assertRaisesRegex(RuntimeError, 'unsupported operation'): + x[:-1] >>= x[1:] + @unittest.skipIf(not TEST_NUMPY, 'Numpy not found') def test_int_pow(self, device): @@ -13448,9 +13517,17 @@ def test_var_mean_some_dims(self, device): @skipCUDAIfRocm def test_blas_empty(self, device): - def fn(torchfn, *args, **kwargs): - return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape - for shape in args), **kwargs) + def fn(torchfn, *args, test_out=False, **kwargs): + def call_torch_fn(*args, **kwargs): + return torchfn(*tuple(torch.randn(shape, device=device) if isinstance(shape, tuple) else shape + for shape in args), **kwargs) + result = call_torch_fn(*args, **kwargs) + if not test_out: + return result + else: + out = torch.full_like(result, math.nan) + out1 = call_torch_fn(*args, **kwargs, out=out) + return out # mm, addmm self.assertEqual((0, 0), fn(torch.mm, (0, 0), (0, 0)).shape) @@ -13458,18 +13535,24 @@ def fn(torchfn, *args, **kwargs): self.assertEqual((5, 0), fn(torch.mm, (5, 0), (0, 0)).shape) self.assertEqual((3, 0), fn(torch.mm, (3, 2), (2, 0)).shape) self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6))) + self.assertEqual(torch.zeros((5, 6), device=device), fn(torch.mm, (5, 0), (0, 6), test_out=True)) self.assertEqual((0, 0), fn(torch.addmm, (0, 0), (0, 0), (0, 0)).shape) - self.assertEqual((5, 6), fn(torch.addmm, (5, 6), (5, 0), (0, 6)).shape) self.assertEqual((0, 1), fn(torch.addmm, (1, ), (0, 17), (17, 1)).shape) + t = torch.randn((5, 6), device=device) + self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6))) + self.assertEqual(t, fn(torch.addmm, t, (5, 0), (0, 6), test_out=True)) # mv, addmv self.assertEqual((0,), fn(torch.mv, (0, 0), (0,)).shape) self.assertEqual((0,), fn(torch.mv, (0, 2), (2,)).shape) self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,))) + self.assertEqual(torch.zeros((3,), device=device), fn(torch.mv, (3, 0), (0,), test_out=True)) self.assertEqual((0,), fn(torch.addmv, (0,), (0, 0), (0,)).shape) - self.assertEqual((3,), fn(torch.addmv, (3,), (3, 0), (0,)).shape) + t = torch.randn((3,), device=device) + self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,))) + self.assertEqual(t, fn(torch.addmv, t, (3, 0), (0,), test_out=True)) # ger, addr self.assertEqual((0, 0), fn(torch.ger, (0,), (0,)).shape) @@ -13485,6 +13568,7 @@ def fn(torchfn, *args, **kwargs): self.assertEqual((3, 0, 5), fn(torch.bmm, (3, 0, 0), (3, 0, 5)).shape) self.assertEqual((0, 5, 6), fn(torch.bmm, (0, 5, 0), (0, 0, 6)).shape) self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6))) + self.assertEqual(torch.zeros((3, 5, 6), device=device), fn(torch.bmm, (3, 5, 0), (3, 0, 6), test_out=True)) self.assertEqual((0, 0, 0), fn(torch.baddbmm, (0, 0, 0), (0, 0, 0), (0, 0, 0)).shape) self.assertEqual((3, 0, 5), fn(torch.baddbmm, (3, 0, 5), (3, 0, 0), (3, 0, 5)).shape) @@ -13492,21 +13576,27 @@ def fn(torchfn, *args, **kwargs): self.assertEqual((3, 5, 6), fn(torch.baddbmm, (3, 5, 6), (3, 5, 0), (3, 0, 6)).shape) c = torch.arange(30, dtype=torch.float32, device=device).reshape(3, 2, 5) self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2)) # Issue #33467 + self.assertEqual(-2 * c, fn(torch.baddbmm, c, (3, 2, 0), (3, 0, 5), beta=-2, test_out=True)) # Issue #33467 # addbmm self.assertEqual((0, 0), fn(torch.addbmm, (0, 0), (0, 0, 0), (0, 0, 0)).shape) self.assertEqual((0, 5), fn(torch.addbmm, (0, 5), (3, 0, 0), (3, 0, 5)).shape) - self.assertEqual((5, 6), fn(torch.addbmm, (5, 6), (0, 5, 0), (0, 0, 6)).shape) + t = torch.randn((5, 6), device=device) + self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6))) + self.assertEqual(t, fn(torch.addbmm, t, (0, 5, 0), (0, 0, 6), test_out=True)) # matmul self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,))) + self.assertEqual(torch.tensor(0., device=device), fn(torch.matmul, (0,), (0,), test_out=True)) self.assertEqual((0, 0), fn(torch.matmul, (0, 0), (0, 0)).shape) self.assertEqual((0, 0, 0), fn(torch.matmul, (0, 0, 0), (0, 0, 0)).shape) self.assertEqual((5, 0, 0), fn(torch.matmul, (5, 0, 0), (5, 0, 0)).shape) self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4))) + self.assertEqual(torch.zeros((5, 3, 4), device=device), fn(torch.matmul, (5, 3, 0), (5, 0, 4), test_out=True)) # dot self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,))) + self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,), test_out=True)) if torch._C.has_lapack: # lu @@ -18093,6 +18183,92 @@ def test_atleast(self, device, dtype): self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype) + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False))) + def test_argminmax_multiple(self, device, dtype): + # Case: All Ones + t = torch.ones(3, 3, device=device, dtype=dtype) + self.compare_with_numpy(torch.argmax, np.argmax, t) + self.compare_with_numpy(torch.argmin, np.argmin, t) + + # Case: With single `nan` present. + if dtype in torch.testing.get_all_fp_dtypes(): + t[2, 2] = float('nan') + self.compare_with_numpy(torch.argmax, np.argmax, t) + self.compare_with_numpy(torch.argmin, np.argmin, t) + + # Case: Randomly Generated Tensors + for ndims in range(1, 5): + shape = self._rand_shape(ndims, min_size=5, max_size=10) + for with_extremal in [False, True]: + for contiguous in [False, True]: + # Generate Input. + x = self._generate_input(shape, dtype, device, with_extremal) + + if dtype == torch.half: + max_val = torch.max(x.to(torch.float)) + min_val = torch.min(x.to(torch.float)) + else: + max_val = torch.max(x) + min_val = torch.min(x) + + mask = torch.randn(x.shape) > 0.5 + x[mask] = torch.tensor(max_val + 1, dtype=dtype) + + mask = torch.randn(x.shape) > 0.5 + x[mask] = torch.tensor(min_val - 1, dtype=dtype) + + if not contiguous: + x = x.T + + self.compare_with_numpy(torch.argmax, np.argmax, x, device=None, dtype=None) + self.compare_with_numpy(torch.argmin, np.argmin, x, device=None, dtype=None) + + # Verify indices returned by max and min. + if dtype != torch.half: + rand_dim = random.randint(0, ndims - 1) + self.compare_with_numpy(lambda x: torch.max(x, dim=rand_dim)[1], + lambda x: np.argmax(x, axis=rand_dim), x, device=None, dtype=None) + self.compare_with_numpy(lambda x: torch.min(x, dim=rand_dim)[1], + lambda x: np.argmin(x, axis=rand_dim), x, device=None, dtype=None) + + def verify_against_numpy(t): + # Argmax + torch_fn = partial(torch.argmax, dim=1) + np_fn = partial(np.argmax, axis=1) + self.compare_with_numpy(torch_fn, np_fn, t) + # Non-contiguous input + self.compare_with_numpy(torch_fn, np_fn, t.T) + + # Verify indices returned by max. + if dtype != torch.half: + self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x, device=None, dtype=None) + self.compare_with_numpy(lambda x: torch.max(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) + + # Argmin + torch_fn = partial(torch.argmin, dim=1) + np_fn = partial(np.argmin, axis=1) + self.compare_with_numpy(torch_fn, np_fn, t) + # Non-contiguous input + self.compare_with_numpy(torch_fn, np_fn, t.T) + + # Verify indices returned by min. + if dtype != torch.half: + self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x, device=None, dtype=None) + self.compare_with_numpy(lambda x: torch.min(x, dim=1)[1], np_fn, x.T, device=None, dtype=None) + + # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 + t = torch.tensor([[1, 5], + [2, 10], + [3, 3]], device=device, dtype=dtype) + verify_against_numpy(t) + + # Case: Sample from issue: https://github.com/pytorch/pytorch/issues/41998 + t = torch.tensor([[1, 5], + [2, 10], + [0, 0]], device=device, dtype=dtype) + verify_against_numpy(t) + def _test_special_stacks(self, dim, at_least_dim, torch_fn, np_fn, device, dtype): # Test error for non-tuple argument with self.assertRaisesRegex(TypeError, "must be tuple of Tensors, not Tensor"): @@ -18581,7 +18757,7 @@ def fn(contiguous_input=True, dim0=0, dim1=1): self.assertEqual(res.shape, torch.Size([0])) @onlyOnCPUAndCUDA - @dtypes(*torch.testing.get_all_complex_dtypes()) + @dtypes(*torch.testing.get_all_complex_dtypes(include_complex32=True)) def test_view_as_real(self, device, dtype): def fn(contiguous_input=True): t = torch.randn(3, 4, dtype=dtype, device=device) @@ -18589,7 +18765,11 @@ def fn(contiguous_input=True): res = torch.view_as_real(input) self.assertEqual(res[:, :, 0], input.real) self.assertEqual(res[:, :, 1], input.imag) - self.assertTrue(self.is_view_of(t, res)) + # TODO: Add torch.ComplexHalfStorage + if dtype != torch.complex32: + self.assertTrue(self.is_view_of(t, res)) + else: + self.assertRaises(RuntimeError, lambda: self.is_view_of(t, res)) fn() fn(contiguous_input=False) @@ -18597,13 +18777,21 @@ def fn(contiguous_input=True): # tensor with zero elements x = torch.tensor([], dtype=dtype, device=device) res = torch.view_as_real(x) - self.assertTrue(self.is_view_of(x, res)) + # TODO: Add torch.ComplexHalfStorage + if dtype != torch.complex32: + self.assertTrue(self.is_view_of(x, res)) + else: + self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) self.assertEqual(res.shape, torch.Size([0, 2])) # tensor with zero dim x = torch.tensor(2 + 3j, dtype=dtype, device=device) res = torch.view_as_real(x) - self.assertTrue(self.is_view_of(x, res)) + # TODO: Add torch.ComplexHalfStorage + if dtype != torch.complex32: + self.assertTrue(self.is_view_of(x, res)) + else: + self.assertRaises(RuntimeError, lambda: self.is_view_of(x, res)) self.assertEqual(res.shape, torch.Size([2])) @onlyOnCPUAndCUDA diff --git a/test/test_xnnpack_integration.py b/test/test_xnnpack_integration.py index 42ae19aa4a6453c..ee0962137d4099d 100644 --- a/test/test_xnnpack_integration.py +++ b/test/test_xnnpack_integration.py @@ -93,6 +93,72 @@ def test_conv2d(self, xnnpack_result = torch.ops.prepacked.conv2d_clamp_run(input_data, packed_weight_bias) torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + @given(batch_size=st.integers(1, 3), + input_channels_per_group=st.integers(1, 32), + height=st.integers(5, 64), + width=st.integers(5, 64), + output_channels_per_group=st.integers(1, 32), + groups=st.integers(1, 16), + kernel_h=st.integers(1, 7), + kernel_w=st.integers(1, 7), + stride_h=st.integers(1, 2), + stride_w=st.integers(1, 2), + pad_h=st.integers(0, 2), + pad_w=st.integers(0, 2), + output_pad_h=st.integers(0, 2), + output_pad_w=st.integers(0, 2), + dilation=st.integers(1, 2), + use_bias=st.booleans(), + format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last])) + def test_conv2d_transpose(self, + batch_size, + input_channels_per_group, + height, + width, + output_channels_per_group, + groups, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + output_pad_h, + output_pad_w, + dilation, + use_bias, + format): + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_h, kernel_w) + strides = (stride_h, stride_w) + paddings = (pad_h, pad_w) + output_paddings = (output_pad_h, output_pad_w) + dilations = (dilation, dilation) + assume(height + 2 * paddings[0] + >= dilations[0] * (kernels[0] - 1) + 1) + assume(width + 2 * paddings[1] + >= dilations[1] * (kernels[1] - 1) + 1) + assume((output_pad_h < stride_h) and (output_pad_h < dilation)) + assume((output_pad_w < stride_w) and (output_pad_w < dilation)) + + input_data = torch.rand((batch_size, input_channels, height, width)) + if (format is not None): + input_data = input_data.contiguous(memory_format=format) + weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w)) + bias = None + if use_bias: + bias = torch.rand((output_channels)) + + # Note that groups/dilation is in reverse order from conv2d + ref_result = F.conv_transpose2d(input_data, weight, bias, + strides, paddings, output_paddings, groups, dilation) + packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias, + strides, paddings, + output_paddings, dilations, + groups) + xnnpack_result = torch.ops.prepacked.conv2d_transpose_clamp_run(input_data, packed_weight_bias) + torch.testing.assert_allclose(ref_result.contiguous(), xnnpack_result.contiguous(), rtol=1e-2, atol=1e-3) @unittest.skipUnless(torch.backends.xnnpack.enabled, " XNNPACK must be enabled for these tests." @@ -244,6 +310,114 @@ def forward(self, x): xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + @given(batch_size=st.integers(0, 3), + input_channels_per_group=st.integers(1, 32), + height=st.integers(5, 64), + width=st.integers(5, 64), + output_channels_per_group=st.integers(1, 32), + groups=st.integers(1, 16), + kernel_h=st.integers(1, 7), + kernel_w=st.integers(1, 7), + stride_h=st.integers(1, 2), + stride_w=st.integers(1, 2), + pad_h=st.integers(0, 2), + pad_w=st.integers(0, 2), + output_pad_h=st.integers(0, 2), + output_pad_w=st.integers(0, 2), + dilation=st.integers(1, 2), + use_bias=st.booleans(), + format=st.sampled_from([None, torch.preserve_format, torch.contiguous_format, torch.channels_last])) + def test_conv2d_transpose(self, + batch_size, + input_channels_per_group, + height, + width, + output_channels_per_group, + groups, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + output_pad_h, + output_pad_w, + dilation, + use_bias, + format): + class Conv2DT(torch.nn.Module): + def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups): + super(Conv2DT, self).__init__() + self.weight = weight + self.bias = bias + self.strides = strides + self.paddings = paddings + self.output_paddings = output_paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + return F.conv_transpose2d(x, self.weight, self.bias, + self.strides, self.paddings, self.output_paddings, self.groups, self.dilations) + + class Conv2DTPrePacked(torch.nn.Module): + def __init__(self, weight, bias, strides, paddings, output_paddings, dilations, groups): + super(Conv2DTPrePacked, self).__init__() + self.packed_weight_bias = torch.ops.prepacked.conv2d_transpose_clamp_prepack(weight, bias, + strides, paddings, + output_paddings, + dilations, groups) + + def forward(self, x): + return torch.ops.prepacked.conv2d_transpose_clamp_run(x, self.packed_weight_bias) + + input_channels = input_channels_per_group * groups + output_channels = output_channels_per_group * groups + kernels = (kernel_h, kernel_w) + strides = (stride_h, stride_w) + paddings = (pad_h, pad_w) + output_paddings = (output_pad_h, output_pad_w) + dilations = (dilation, dilation) + assume(height + 2 * paddings[0] >= + dilations[0] * (kernels[0] - 1) + 1) + assume(width + 2 * paddings[1] >= + dilations[1] * (kernels[1] - 1) + 1) + assume((output_pad_h < stride_h) and (output_pad_h < dilation)) + assume((output_pad_w < stride_w) and (output_pad_w < dilation)) + + input_data = torch.rand((batch_size, input_channels, height, width)) + if (format is not None): + input_data = input_data.contiguous(memory_format=format) + weight = torch.rand((input_channels, output_channels_per_group, kernel_h, kernel_w)) + bias = None + if use_bias: + bias = torch.rand((output_channels)) + + scripted_conv2d = torch.jit.script(Conv2DT(weight, bias, + strides, paddings, + output_paddings, dilations, groups)) + scripted_conv2d_clamp_prepacked = torch.jit.script(Conv2DTPrePacked( + weight, bias, strides, paddings, output_paddings, dilations, groups)) + ref_result = scripted_conv2d(input_data) + xnnpack_result = scripted_conv2d_clamp_prepacked(input_data) + torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + + # Serialize the modules and then deserialize + input_data = torch.rand((batch_size, input_channels, height, width)) + if (format is not None): + input_data = input_data.contiguous(memory_format=format) + buffer = io.BytesIO() + torch.jit.save(scripted_conv2d, buffer) + buffer.seek(0) + deserialized_conv2d = torch.jit.load(buffer) + buffer = io.BytesIO() + torch.jit.save(scripted_conv2d_clamp_prepacked, buffer) + buffer.seek(0) + deserialized_conv2d_clamp_prepacked = torch.jit.load(buffer) + ref_result = deserialized_conv2d(input_data) + xnnpack_result = deserialized_conv2d_clamp_prepacked(input_data) + torch.testing.assert_allclose(ref_result, xnnpack_result, rtol=1e-2, atol=1e-3) + @given(batch_size=st.integers(0, 3), input_channels_per_group=st.integers(1, 32), height=st.integers(5, 64), @@ -454,14 +628,17 @@ def forward(self, x): kernel_h = kernel_w = 3 stride_h = stride_w = 1 pad_h = pad_w = 1 + output_pad_h = output_pad_w = 0 dilation = 1 input_channels = input_channels_per_group * groups output_channels = output_channels_per_group * groups kernels = (kernel_h, kernel_w) strides = (stride_h, stride_w) paddings = (pad_h, pad_w) + output_paddings = (output_pad_h, output_pad_w) dilations = (dilation, dilation) conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w) + conv_transpose_weight_shape = (input_channels, output_channels_per_group, kernel_h, kernel_w) conv_bias_shape = (output_channels) class Conv2D(torch.nn.Module): @@ -478,12 +655,34 @@ def forward(self, x): return F.conv2d(x, self.weight, self.bias, self.strides, self.paddings, self.dilations, self.groups) + class Conv2DT(torch.nn.Module): + def __init__(self): + super(Conv2DT, self).__init__() + self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_transpose_weight_shape)), requires_grad=False) + self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False) + self.strides = strides + self.paddings = paddings + self.output_paddings = output_paddings + self.dilations = dilations + self.groups = groups + + def forward(self, x): + return F.conv_transpose2d(x, self.weight, self.bias, + self.strides, self.paddings, self.output_paddings, self.groups, self.dilations) + + data_shape = (batch_size, input_channels, height, width) pattern_count_map = {"Tensor = aten::conv2d": -1, "prepacked::conv2d_clamp_prepack": 1, "prepacked::conv2d_clamp_run": 1} TestXNNPACKRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape) + transpose_data_shape = (batch_size, input_channels, height, width) + transpose_pattern_count_map = {"Tensor = aten::conv_transpose2d": -1, + "prepacked::conv2d_transpose_clamp_prepack": 1, + "prepacked::conv2d_transpose_clamp_run": 1} + TestXNNPACKRewritePass.validate_transformed_module(Conv2DT(), transpose_pattern_count_map, data_shape) + input_data = torch.rand((batch_size, input_channels, height, width)) conv_weight = torch.rand((output_channels, input_channels_per_group, kernel_h, kernel_w)) conv_bias = torch.rand((output_channels)) diff --git a/third_party/substitution.bzl b/third_party/substitution.bzl index bcc24cae70879c6..2a288cd813d81d6 100644 --- a/third_party/substitution.bzl +++ b/third_party/substitution.bzl @@ -42,3 +42,39 @@ template_rule = rule( output_to_genfiles = True, implementation = template_rule_impl, ) + +# Header template rule is an extension of template substitution rule +# That also makes this header a valid dependency for cc_library +# From https://stackoverflow.com/a/55407399 +def header_template_rule_impl(ctx): + ctx.actions.expand_template( + template = ctx.file.src, + output = ctx.outputs.out, + substitutions = ctx.attr.substitutions, + ) + return [ + # create a provider which says that this + # out file should be made available as a header + CcInfo(compilation_context=cc_common.create_compilation_context( + + # pass out the include path for finding this header + includes=depset([ctx.outputs.out.dirname, ctx.bin_dir.path]), + + # and the actual header here. + headers=depset([ctx.outputs.out]) + )) + ] + +header_template_rule = rule( + attrs = { + "src": attr.label( + mandatory = True, + allow_single_file = True, + ), + "out": attr.output(mandatory = True), + "substitutions": attr.string_dict(mandatory = True), + }, + # output_to_genfiles is required for header files. + output_to_genfiles = True, + implementation = header_template_rule_impl, +) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 5846b940e569cf2..bc73ec3456783e8 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -393,9 +393,12 @@ self: norm_backward(grad, self - other, p, result) other: -norm_backward(grad, self - other, p, result) +# The backward formula is done in this order to improve numerical stability +# of the higher order derivatives, see https://github.com/pytorch/pytorch/issues/43414 +# Note that we don't use "result" because saving it would be BC-breaking when it is used in an inplace operation later - name: div.Tensor(Tensor self, Tensor other) -> Tensor self: grad / other - other: -grad * self / (other * other) + other: -grad * (self / other) / other - name: div.Scalar(Tensor self, Scalar other) -> Tensor self: grad / other diff --git a/tools/autograd/templates/python_linalg_functions.cpp b/tools/autograd/templates/python_linalg_functions.cpp index fa139eef0b875ed..b02438e31189c1e 100644 --- a/tools/autograd/templates/python_linalg_functions.cpp +++ b/tools/autograd/templates/python_linalg_functions.cpp @@ -12,6 +12,7 @@ using at::Tensor; using at::Scalar; +using at::ScalarType; using at::MemoryFormat; using at::Generator; using at::IntArrayRef; diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index ac8e7f929866b9a..3c8642de1a7448b 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -91,9 +91,7 @@ core_sources_common = [ ] jit_sources_common = [ - "torch/csrc/jit/runtime/register_prim_ops.cpp", "torch/csrc/jit/runtime/register_prim_ops_c10.cpp", - "torch/csrc/jit/runtime/register_special_ops.cpp", ] libtorch_sources_common = core_sources_common + jit_sources_common @@ -241,6 +239,7 @@ core_sources_full = [ "torch/csrc/jit/tensorexpr/kernel.cpp", "torch/csrc/jit/tensorexpr/llvm_codegen.cpp", "torch/csrc/jit/tensorexpr/llvm_jit.cpp", + "torch/csrc/jit/tensorexpr/block_codegen.cpp", "torch/csrc/jit/tensorexpr/loopnest.cpp", "torch/csrc/jit/tensorexpr/mem_arena.cpp", "torch/csrc/jit/tensorexpr/registerizer.cpp", @@ -296,7 +295,9 @@ jit_sources_full = [ "torch/csrc/jit/codegen/cuda/interface.cpp", "torch/csrc/jit/passes/lower_graph.cpp", "torch/csrc/jit/runtime/register_c10_ops.cpp", + "torch/csrc/jit/runtime/register_prim_ops.cpp", "torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp", + "torch/csrc/jit/runtime/register_special_ops.cpp", "torch/csrc/jit/runtime/register_string_ops.cpp", "torch/csrc/jit/passes/inline_fork_wait.cpp", "torch/csrc/jit/passes/remove_inplace_ops.cpp", diff --git a/tools/code_analyzer/build.sh b/tools/code_analyzer/build.sh index 3cf5999da97ddfe..1081087e4d4f624 100755 --- a/tools/code_analyzer/build.sh +++ b/tools/code_analyzer/build.sh @@ -62,7 +62,6 @@ build_torch_mobile() { BUILD_ROOT="${TORCH_BUILD_ROOT}" "${SRC_ROOT}/scripts/build_mobile.sh" \ -DCMAKE_CXX_FLAGS="-S -emit-llvm -DSTRIP_ERROR_MESSAGES" \ - -DUSE_STATIC_DISPATCH=OFF \ ${MOBILE_BUILD_FLAGS} } diff --git a/tools/code_analyzer/default_op_deps.yaml b/tools/code_analyzer/default_op_deps.yaml new file mode 100644 index 000000000000000..98b289aa632b266 --- /dev/null +++ b/tools/code_analyzer/default_op_deps.yaml @@ -0,0 +1,10823 @@ +- name: __ROOT__ + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::_indices + - name: aten::_sparse_coo_tensor_unsafe + - name: aten::_values + - name: aten::_version + - name: aten::add + - name: aten::add_ + - name: aten::any + - name: aten::as_strided_ + - name: aten::cat + - name: aten::chunk + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::dequantize + - name: aten::detach + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::equal + - name: aten::expand + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::isnan + - name: aten::item + - name: aten::lt + - name: aten::mm + - name: aten::mul + - name: aten::narrow + - name: aten::ones_like + - name: aten::output_nr + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::reshape + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::rsqrt + - name: aten::scalar_tensor + - name: aten::select + - name: aten::set_ + - name: aten::set_data + - name: aten::size + - name: aten::stride + - name: aten::sub + - name: aten::sum + - name: aten::t + - name: aten::to + - name: aten::view + - name: aten::zero_ + - name: aten::zeros + - name: aten::zeros_like +- name: _quantized::add + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::conv2d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::conv2d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros +- name: _quantized::conv2d_relu + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::linear + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::linear_dynamic + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::linear_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::linear_prepack_fp16 + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::linear_prepack_fp16_legacy + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _quantized::linear_prepack_legacy + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _test::cat + depends: + - name: aten::cat + - name: aten::eq + - name: aten::is_nonzero +- name: _test::get_first + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: _test::leaky_relu + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::leaky_relu +- name: aten::Int + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item +- name: aten::__and__ + depends: + - name: aten::bitwise_and + - name: aten::eq + - name: aten::is_nonzero +- name: aten::__iand__ + depends: + - name: aten::bitwise_and_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::__ilshift__ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::__ior__ + depends: + - name: aten::bitwise_or_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::__irshift__ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::__ixor__ + depends: + - name: aten::bitwise_xor_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::__lshift__ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::__or__ + depends: + - name: aten::bitwise_or + - name: aten::eq + - name: aten::is_nonzero +- name: aten::__rshift__ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::__xor__ + depends: + - name: aten::bitwise_xor + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_adaptive_avg_pool2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size + - name: aten::stride +- name: aten::_adaptive_avg_pool2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::zeros_like +- name: aten::_add_batch_dim + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_addmv_impl_ + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::_addr + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::zero_ +- name: aten::_addr_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::zero_ +- name: aten::_amp_non_finite_check_and_unscale_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_amp_update_scale + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_baddbmm_mkl_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_batch_norm_impl_index + depends: + - name: aten::contiguous + - name: aten::cudnn_batch_norm + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::miopen_batch_norm + - name: aten::native_batch_norm + - name: aten::size +- name: aten::_batch_norm_impl_index_backward + depends: + - name: aten::cudnn_batch_norm_backward + - name: aten::eq + - name: aten::is_nonzero + - name: aten::miopen_batch_norm_backward + - name: aten::native_batch_norm_backward +- name: aten::_bmm + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cast_Byte + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Char + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Double + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Float + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Half + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Int + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Long + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cast_Short + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_cat + depends: + - name: aten::_cat + - name: aten::_copy_from + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::cat + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::_cdist_backward + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::size + - name: aten::view +- name: aten::_cdist_forward + depends: + - name: aten::_euclidean_dist + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::size + - name: aten::view + - name: aten::zeros +- name: aten::_cholesky_helper + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ +- name: aten::_cholesky_solve_helper + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ +- name: aten::_choose_qparams_per_tensor + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min +- name: aten::_coalesced_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_compute_linear_combination + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros +- name: aten::_convolution + depends: + - name: aten::_convolution + - name: aten::_convolution_nogroup + - name: aten::_unsafe_view + - name: aten::add_ + - name: aten::cat + - name: aten::contiguous + - name: aten::convolution_overrideable + - name: aten::copy_ + - name: aten::cudnn_convolution + - name: aten::cudnn_convolution_transpose + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::miopen_convolution + - name: aten::miopen_convolution_transpose + - name: aten::miopen_depthwise_convolution + - name: aten::mul + - name: aten::narrow + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::slow_conv3d + - name: aten::squeeze + - name: aten::thnn_conv_depthwise2d + - name: aten::to + - name: aten::unsqueeze + - name: aten::view +- name: aten::_convolution_double_backward + depends: + - name: aten::_convolution + - name: aten::add + - name: aten::cat + - name: aten::contiguous + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::size + - name: aten::transpose + - name: aten::view +- name: aten::_convolution_nogroup + depends: + - name: aten::_convolution_nogroup + - name: aten::_nnpack_available + - name: aten::_nnpack_spatial_convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::slow_conv3d + - name: aten::slow_conv_dilated2d + - name: aten::slow_conv_dilated3d + - name: aten::slow_conv_transpose2d + - name: aten::slow_conv_transpose3d + - name: aten::thnn_conv2d + - name: aten::to +- name: aten::_copy_from + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_ctc_loss + depends: + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::permute + - name: aten::size + - name: aten::stride +- name: aten::_ctc_loss_backward + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::full_like + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::permute + - name: aten::size + - name: aten::stride + - name: aten::zero_ +- name: aten::_cudnn_ctc_loss + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cudnn_init_dropout_state + depends: + - name: aten::_cudnn_init_dropout_state + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cudnn_rnn + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cudnn_rnn_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cudnn_rnn_flatten_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cufft_clear_plan_cache + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cufft_get_plan_cache_max_size + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cufft_get_plan_cache_size + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cufft_set_plan_cache_max_size + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_cummax_helper + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::_cummin_helper + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::_cumprod + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::_cumsum + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::_debug_has_internal_overlap + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_dimI + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_dimV + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_dim_arange + depends: + - name: aten::arange + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_dirichlet_grad + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_embedding_bag + depends: + - name: aten::copy_ + - name: aten::cumsum + - name: aten::div_ + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::expand_as + - name: aten::fill_ + - name: aten::index_add_ + - name: aten::is_nonzero + - name: aten::max + - name: aten::ones_like + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::slice + - name: aten::stride + - name: aten::sub + - name: aten::sub_ + - name: aten::to + - name: aten::unsqueeze + - name: aten::zero_ + - name: aten::zeros +- name: aten::_embedding_bag_backward + depends: + - name: aten::_embedding_bag_dense_backward + - name: aten::_embedding_bag_sparse_backward + - name: aten::cumsum + - name: aten::eq + - name: aten::index_add_ + - name: aten::is_nonzero + - name: aten::ones_like + - name: aten::resize_ + - name: aten::select + - name: aten::sub_ + - name: aten::zeros +- name: aten::_embedding_bag_dense_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::index_add_ + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::nonzero + - name: aten::select + - name: aten::size + - name: aten::sort + - name: aten::stride + - name: aten::view + - name: aten::zeros +- name: aten::_embedding_bag_forward_only + depends: + - name: aten::copy_ + - name: aten::cumsum + - name: aten::div_ + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::expand_as + - name: aten::fill_ + - name: aten::index_add_ + - name: aten::is_nonzero + - name: aten::max + - name: aten::ones_like + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::slice + - name: aten::stride + - name: aten::sub + - name: aten::sub_ + - name: aten::to + - name: aten::unsqueeze + - name: aten::zero_ + - name: aten::zeros +- name: aten::_embedding_bag_per_sample_weights_backward + depends: + - name: aten::cumsum + - name: aten::eq + - name: aten::index_add_ + - name: aten::is_nonzero + - name: aten::ones_like + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::sub_ + - name: aten::zeros +- name: aten::_embedding_bag_sparse_backward + depends: + - name: aten::div_ + - name: aten::embedding_dense_backward + - name: aten::embedding_sparse_backward + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::size + - name: aten::to + - name: aten::unsqueeze +- name: aten::_empty_affine_quantized + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_empty_per_channel_affine_quantized + depends: + - name: aten::_empty_per_channel_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::size + - name: aten::to +- name: aten::_euclidean_dist + depends: + - name: aten::cat + - name: aten::clamp_min_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::matmul + - name: aten::mul + - name: aten::ones_like + - name: aten::pow + - name: aten::sqrt_ + - name: aten::sum + - name: aten::transpose +- name: aten::_fake_quantize_learnable_per_channel_affine + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to + - name: aten::view +- name: aten::_fake_quantize_learnable_per_channel_affine_backward + depends: + - name: aten::as_strided_ + - name: aten::clamp + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::sum + - name: aten::to +- name: aten::_fake_quantize_learnable_per_tensor_affine + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::to +- name: aten::_fake_quantize_learnable_per_tensor_affine_backward + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::select + - name: aten::sum + - name: aten::to + - name: aten::unsqueeze +- name: aten::_fft_with_size + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_foreach_add + depends: + - name: aten::add + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_fused_dropout + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_gather_sparse_backward + depends: + - name: aten::_sparse_coo_tensor_unsafe + - name: aten::arange + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::repeat + - name: aten::reshape + - name: aten::select + - name: aten::size + - name: aten::unsqueeze + - name: aten::view +- name: aten::_grid_sampler_2d_cpu_fallback + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::_grid_sampler_2d_cpu_fallback_backward + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::_has_compatible_shallow_copy_type + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_index_copy_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::_index_put_impl_ + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::nonzero + - name: aten::permute + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::_indices + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_inverse_helper + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ +- name: aten::_local_scalar_dense + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_log_softmax + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::size + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::view +- name: aten::_log_softmax_backward_data + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::size + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::view +- name: aten::_logcumsumexp + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::_lu_solve_helper + depends: + - name: aten::clone + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ + - name: aten::zeros_like +- name: aten::_lu_with_info + depends: + - name: aten::clone + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::size + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::zeros +- name: aten::_make_per_channel_quantized_tensor + depends: + - name: aten::_empty_per_channel_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_make_per_tensor_quantized_tensor + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_masked_scale + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_min_max + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::_mkldnn_reshape + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_mkldnn_transpose + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_mkldnn_transpose_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_mode + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::set_ +- name: aten::_multinomial_alias_draw + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_multinomial_alias_setup + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::set_ +- name: aten::_nnpack_available + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_nnpack_spatial_convolution + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::zeros +- name: aten::_nnpack_spatial_convolution_backward + depends: + - name: aten::_nnpack_spatial_convolution_backward_input + - name: aten::_nnpack_spatial_convolution_backward_weight + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::sum + - name: aten::view +- name: aten::_nnpack_spatial_convolution_backward_input + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_nnpack_spatial_convolution_backward_weight + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_nnz + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_pack_padded_sequence + depends: + - name: aten::cat + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::slice + - name: aten::transpose + - name: aten::view +- name: aten::_pack_padded_sequence_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::select + - name: aten::size + - name: aten::slice + - name: aten::transpose + - name: aten::zeros +- name: aten::_pad_packed_sequence + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::full + - name: aten::is_nonzero + - name: aten::size + - name: aten::slice + - name: aten::transpose + - name: aten::view +- name: aten::_pdist_backward + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::_pdist_forward + depends: + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::_qr_helper + depends: + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand_as + - name: aten::eye + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::size +- name: aten::_remove_batch_dim + depends: + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::permute +- name: aten::_reshape_from_tensor + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape +- name: aten::_s_where + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::_sample_dirichlet + depends: + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::sum + - name: aten::zeros +- name: aten::_saturate_weight_to_fp16 + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_shape_as_tensor + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_sobol_engine_draw + depends: + - name: aten::clone + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::stride +- name: aten::_sobol_engine_ff_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::stride +- name: aten::_sobol_engine_initialize_state_ + depends: + - name: aten::arange + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::pow + - name: aten::select +- name: aten::_sobol_engine_scramble_ + depends: + - name: aten::arange + - name: aten::clone + - name: aten::diagonal + - name: aten::eq + - name: aten::expand_as + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul + - name: aten::pow + - name: aten::size + - name: aten::sum +- name: aten::_softmax + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::size + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::view +- name: aten::_softmax_backward_data + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::size + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::view +- name: aten::_solve_helper + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ +- name: aten::_sparse_addmm + depends: + - name: aten::addmm + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_sparse_coo_tensor_unsafe + depends: + - name: aten::_sparse_coo_tensor_unsafe + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::size +- name: aten::_sparse_coo_tensor_with_dims + depends: + - name: aten::_sparse_coo_tensor_with_dims + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_sparse_coo_tensor_with_dims_and_tensors + depends: + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_sparse_log_softmax + depends: + - name: aten::_sparse_log_softmax + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_sparse_log_softmax_backward_data + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_sparse_mm + depends: + - name: aten::_sparse_addmm + - name: aten::eq + - name: aten::is_nonzero + - name: aten::zeros +- name: aten::_sparse_softmax + depends: + - name: aten::_sparse_softmax + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::_sparse_softmax_backward_data + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_sparse_sum + depends: + - name: aten::_indices + - name: aten::_nnz + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::_sparse_sum + - name: aten::_values + - name: aten::clone + - name: aten::coalesce + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::select + - name: aten::size + - name: aten::sparse_dim + - name: aten::sum + - name: aten::to + - name: aten::values +- name: aten::_sparse_sum_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_standard_gamma + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::zeros +- name: aten::_standard_gamma_grad + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_std + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::mean + - name: aten::scalar_tensor +- name: aten::_svd_helper + depends: + - name: aten::clone + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ + - name: aten::zero_ +- name: aten::_symeig_helper + depends: + - name: aten::clone + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ +- name: aten::_test_optional_filled_intlist + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_test_optional_floatlist + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_test_optional_intlist + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::_test_serialization_subcmul + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::sub +- name: aten::_thnn_differentiable_gru_cell_backward + depends: + - name: aten::add + - name: aten::cat + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul + - name: aten::sigmoid + - name: aten::sigmoid_backward + - name: aten::sub + - name: aten::sub_ + - name: aten::sum + - name: aten::tanh + - name: aten::tanh_backward + - name: aten::unsafe_chunk +- name: aten::_thnn_differentiable_lstm_cell_backward + depends: + - name: aten::add + - name: aten::cat + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::sigmoid + - name: aten::sigmoid_backward + - name: aten::sum + - name: aten::tanh + - name: aten::tanh_backward + - name: aten::unsafe_chunk + - name: aten::zeros_like +- name: aten::_thnn_fused_gru_cell + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_thnn_fused_gru_cell_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_thnn_fused_lstm_cell + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_thnn_fused_lstm_cell_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_triangular_solve_helper + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::transpose + - name: aten::transpose_ +- name: aten::_trilinear + depends: + - name: aten::add_ + - name: aten::bmm + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::narrow + - name: aten::permute + - name: aten::reshape + - name: aten::size + - name: aten::squeeze_ + - name: aten::sum + - name: aten::unsqueeze + - name: aten::view + - name: aten::zeros +- name: aten::_unique + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ +- name: aten::_unique2 + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ +- name: aten::_unsafe_view + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::view +- name: aten::_use_cudnn_ctc_loss + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_use_cudnn_rnn_flatten_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_validate_sparse_coo_tensor_args + depends: + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::max + - name: aten::min + - name: aten::size + - name: aten::to +- name: aten::_values + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_var + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::mean + - name: aten::scalar_tensor +- name: aten::_version + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_weight_norm + depends: + - name: aten::_weight_norm_cuda_interface + - name: aten::contiguous + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::norm_except_dim +- name: aten::_weight_norm_cuda_interface + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_weight_norm_cuda_interface_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::_weight_norm_differentiable_backward + depends: + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::size + - name: aten::sub + - name: aten::sum + - name: aten::to + - name: aten::view +- name: aten::abs + depends: + - name: aten::abs + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::real + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::abs_ + depends: + - name: aten::abs + - name: aten::eq + - name: aten::is_nonzero +- name: aten::absolute + depends: + - name: aten::abs + - name: aten::eq + - name: aten::is_nonzero +- name: aten::absolute_ + depends: + - name: aten::abs_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::acos + depends: + - name: aten::acos + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::acos_ + depends: + - name: aten::acos + - name: aten::eq + - name: aten::is_nonzero +- name: aten::acosh + depends: + - name: aten::acosh + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::acosh_ + depends: + - name: aten::acosh + - name: aten::eq + - name: aten::is_nonzero +- name: aten::adaptive_avg_pool1d + depends: + - name: aten::adaptive_avg_pool2d + - name: aten::eq + - name: aten::is_nonzero + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::adaptive_avg_pool2d + depends: + - name: aten::_adaptive_avg_pool2d + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mean + - name: aten::mkldnn_adaptive_avg_pool2d + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::view +- name: aten::adaptive_avg_pool3d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size + - name: aten::stride +- name: aten::adaptive_avg_pool3d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::adaptive_max_pool1d + depends: + - name: aten::adaptive_max_pool2d + - name: aten::eq + - name: aten::is_nonzero + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::adaptive_max_pool2d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride +- name: aten::adaptive_max_pool2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::adaptive_max_pool3d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride +- name: aten::adaptive_max_pool3d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::add + depends: + - name: aten::add + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_meta + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::permute + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::size + - name: aten::to + - name: aten::view +- name: aten::add_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::add_relu + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::add_relu_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::addbmm + depends: + - name: aten::addbmm + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ +- name: aten::addbmm_ + depends: + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::transpose + - name: aten::transpose_ +- name: aten::addcdiv + depends: + - name: aten::addcdiv + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::addcdiv_ + depends: + - name: aten::addcdiv + - name: aten::eq + - name: aten::is_nonzero +- name: aten::addcmul + depends: + - name: aten::addcmul + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::addcmul_ + depends: + - name: aten::addcmul + - name: aten::eq + - name: aten::is_nonzero +- name: aten::addmm + depends: + - name: aten::addmm + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ +- name: aten::addmm_ + depends: + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::transpose + - name: aten::transpose_ +- name: aten::addmv + depends: + - name: aten::_addmv_impl_ + - name: aten::_copy_from + - name: aten::addmv + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::zero_ +- name: aten::addmv_ + depends: + - name: aten::_addmv_impl_ + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::zero_ +- name: aten::addr + depends: + - name: aten::_addr + - name: aten::addr + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::size + - name: aten::to +- name: aten::addr_ + depends: + - name: aten::_addr_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::affine_grid_generator + depends: + - name: aten::bmm + - name: aten::copy_ + - name: aten::div + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::linspace + - name: aten::mul + - name: aten::select + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze_ + - name: aten::view +- name: aten::affine_grid_generator_backward + depends: + - name: aten::bmm + - name: aten::copy_ + - name: aten::div + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::linspace + - name: aten::mul + - name: aten::select + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze_ + - name: aten::view +- name: aten::alias + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::align_as + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::is_nonzero +- name: aten::align_tensors + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::rename + - name: aten::view +- name: aten::align_to + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::is_nonzero +- name: aten::all + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::squeeze_ + - name: aten::to + - name: aten::unsqueeze +- name: aten::allclose + depends: + - name: aten::all + - name: aten::eq + - name: aten::is_nonzero + - name: aten::isclose + - name: aten::item +- name: aten::alpha_dropout + depends: + - name: aten::add + - name: aten::add_ + - name: aten::bernoulli_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::mul_ + - name: aten::zeros +- name: aten::alpha_dropout_ + depends: + - name: aten::add + - name: aten::add_ + - name: aten::bernoulli_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::zeros +- name: aten::angle + depends: + - name: aten::angle + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::real + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::any + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::squeeze_ + - name: aten::to + - name: aten::unsqueeze +- name: aten::arange + depends: + - name: aten::arange + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::arccos + depends: + - name: aten::acos + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::arccos_ + depends: + - name: aten::acos + - name: aten::eq + - name: aten::is_nonzero +- name: aten::arccosh + depends: + - name: aten::acosh + - name: aten::eq + - name: aten::is_nonzero +- name: aten::arccosh_ + depends: + - name: aten::acosh_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::arcsin + depends: + - name: aten::as_strided_ + - name: aten::asin + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::arcsin_ + depends: + - name: aten::asin + - name: aten::eq + - name: aten::is_nonzero +- name: aten::arctan + depends: + - name: aten::as_strided_ + - name: aten::atan + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::arctan_ + depends: + - name: aten::atan + - name: aten::eq + - name: aten::is_nonzero +- name: aten::argmax + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::resize_ + - name: aten::to + - name: aten::zeros +- name: aten::argmin + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::resize_ + - name: aten::to + - name: aten::zeros +- name: aten::argsort + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sort +- name: aten::as_strided + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::as_strided_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::asin + depends: + - name: aten::as_strided_ + - name: aten::asin + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::asin_ + depends: + - name: aten::asin + - name: aten::eq + - name: aten::is_nonzero +- name: aten::asinh + depends: + - name: aten::as_strided_ + - name: aten::asinh + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::asinh_ + depends: + - name: aten::asinh + - name: aten::eq + - name: aten::is_nonzero +- name: aten::atan + depends: + - name: aten::as_strided_ + - name: aten::atan + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::atan2 + depends: + - name: aten::as_strided_ + - name: aten::atan2 + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::atan2_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::atan_ + depends: + - name: aten::atan + - name: aten::eq + - name: aten::is_nonzero +- name: aten::atanh + depends: + - name: aten::as_strided_ + - name: aten::atanh + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::atanh_ + depends: + - name: aten::atanh + - name: aten::eq + - name: aten::is_nonzero +- name: aten::atleast_1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape +- name: aten::atleast_2d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::unsqueeze +- name: aten::atleast_3d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::unsqueeze +- name: aten::avg_pool1d + depends: + - name: aten::avg_pool2d + - name: aten::eq + - name: aten::is_nonzero + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::avg_pool2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size +- name: aten::avg_pool2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::avg_pool3d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size +- name: aten::avg_pool3d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::baddbmm + depends: + - name: aten::addmm_ + - name: aten::baddbmm + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::zero_ +- name: aten::baddbmm_ + depends: + - name: aten::addmm_ + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::transpose + - name: aten::transpose_ + - name: aten::zero_ +- name: aten::bartlett_window + depends: + - name: aten::add_ + - name: aten::arange + - name: aten::bartlett_window + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::narrow +- name: aten::batch_norm + depends: + - name: aten::_batch_norm_impl_index + - name: aten::add + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::select +- name: aten::batch_norm_backward_elemt + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::batch_norm_backward_reduce + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::batch_norm_elemt + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::batch_norm_gather_stats + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::batch_norm_gather_stats_with_counts + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::batch_norm_stats + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::batch_norm_update_stats + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::to +- name: aten::bernoulli + depends: + - name: aten::as_strided_ + - name: aten::bernoulli_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::bernoulli_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::bilinear + depends: + - name: aten::_trilinear + - name: aten::add + - name: aten::bilinear + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::to + - name: aten::view +- name: aten::binary_cross_entropy + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mean + - name: aten::mul_ + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::squeeze + - name: aten::sum + - name: aten::to +- name: aten::binary_cross_entropy_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::div_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::squeeze + - name: aten::to +- name: aten::binary_cross_entropy_with_logits + depends: + - name: aten::add_ + - name: aten::binary_cross_entropy_with_logits + - name: aten::clamp_min_ + - name: aten::empty_like + - name: aten::eq + - name: aten::exp_ + - name: aten::fill_ + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::log_ + - name: aten::mean + - name: aten::mul + - name: aten::mul_ + - name: aten::neg + - name: aten::sub + - name: aten::sub_ + - name: aten::sum + - name: aten::to +- name: aten::binary_cross_entropy_with_logits_backward + depends: + - name: aten::add + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::mul_ + - name: aten::sigmoid + - name: aten::sub + - name: aten::sub_ +- name: aten::bincount + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::max + - name: aten::min + - name: aten::size + - name: aten::to + - name: aten::zero_ +- name: aten::binomial + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::zeros +- name: aten::bitwise_and + depends: + - name: aten::as_strided_ + - name: aten::bitwise_and + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::bitwise_and_ + depends: + - name: aten::bitwise_and + - name: aten::eq + - name: aten::is_nonzero +- name: aten::bitwise_not + depends: + - name: aten::as_strided_ + - name: aten::bitwise_not + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::bitwise_not_ + depends: + - name: aten::bitwise_not + - name: aten::eq + - name: aten::is_nonzero +- name: aten::bitwise_or + depends: + - name: aten::as_strided_ + - name: aten::bitwise_or + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::bitwise_or_ + depends: + - name: aten::bitwise_or + - name: aten::eq + - name: aten::is_nonzero +- name: aten::bitwise_xor + depends: + - name: aten::as_strided_ + - name: aten::bitwise_xor + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::bitwise_xor_ + depends: + - name: aten::bitwise_xor + - name: aten::eq + - name: aten::is_nonzero +- name: aten::blackman_window + depends: + - name: aten::add + - name: aten::arange + - name: aten::blackman_window + - name: aten::cos_ + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul + - name: aten::mul_ + - name: aten::narrow + - name: aten::sub +- name: aten::block_diag + depends: + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::size + - name: aten::slice + - name: aten::zeros +- name: aten::bmm + depends: + - name: aten::addmm_ + - name: aten::bmm + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ + - name: aten::zero_ +- name: aten::broadcast_tensors + depends: + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero +- name: aten::bucketize + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::can_cast + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::capitalize + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cartesian_prod + depends: + - name: aten::eq + - name: aten::flatten + - name: aten::is_nonzero + - name: aten::meshgrid + - name: aten::stack +- name: aten::cat + depends: + - name: aten::_cat + - name: aten::_indices + - name: aten::_nnz + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::_values + - name: aten::add_ + - name: aten::cat + - name: aten::dense_dim + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::max + - name: aten::min + - name: aten::narrow + - name: aten::select + - name: aten::size + - name: aten::sparse_dim + - name: aten::to + - name: aten::zero_ +- name: aten::cauchy_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::cdist + depends: + - name: aten::_cdist_forward + - name: aten::_euclidean_dist + - name: aten::cdist + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::size + - name: aten::to + - name: aten::view + - name: aten::zeros +- name: aten::ceil + depends: + - name: aten::as_strided_ + - name: aten::ceil + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::ceil_ + depends: + - name: aten::ceil + - name: aten::eq + - name: aten::is_nonzero +- name: aten::celu + depends: + - name: aten::elu + - name: aten::eq + - name: aten::is_nonzero +- name: aten::celu_ + depends: + - name: aten::elu_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::center + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::chain_matmul + depends: + - name: aten::chain_matmul + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mm + - name: aten::size + - name: aten::to +- name: aten::channel_shuffle + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::reshape + - name: aten::size + - name: aten::view +- name: aten::cholesky + depends: + - name: aten::_cholesky_helper + - name: aten::copy_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::tril_ + - name: aten::triu_ +- name: aten::cholesky_inverse + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::cholesky_solve + depends: + - name: aten::_cholesky_solve_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size +- name: aten::chunk + depends: + - name: aten::chunk + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::size + - name: aten::split + - name: aten::split_with_sizes +- name: aten::clamp + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::clamp + - name: aten::clamp_max + - name: aten::clamp_min + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: aten::clamp_ + depends: + - name: aten::clamp + - name: aten::eq + - name: aten::is_nonzero +- name: aten::clamp_max + depends: + - name: aten::as_strided_ + - name: aten::clamp_max + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::clamp_max_ + depends: + - name: aten::clamp_max + - name: aten::eq + - name: aten::is_nonzero +- name: aten::clamp_min + depends: + - name: aten::as_strided_ + - name: aten::clamp_min + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::clamp_min_ + depends: + - name: aten::clamp_min + - name: aten::eq + - name: aten::is_nonzero +- name: aten::clip + depends: + - name: aten::clamp + - name: aten::eq + - name: aten::is_nonzero +- name: aten::clip_ + depends: + - name: aten::clamp_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::clone + depends: + - name: aten::_copy_from + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::coalesce + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::col2im + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::col2im_backward + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::combinations + depends: + - name: aten::arange + - name: aten::eq + - name: aten::full + - name: aten::is_nonzero + - name: aten::le + - name: aten::lt + - name: aten::masked_select + - name: aten::meshgrid + - name: aten::mul_ + - name: aten::stack +- name: aten::complex + depends: + - name: aten::as_strided_ + - name: aten::complex + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::conj + depends: + - name: aten::as_strided_ + - name: aten::conj + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::constant_pad_nd + depends: + - name: aten::clone + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::size +- name: aten::contiguous + depends: + - name: aten::copy_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero +- name: aten::conv1d + depends: + - name: aten::conv1d + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::conv2d + depends: + - name: aten::conv2d + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::conv3d + depends: + - name: aten::conv3d + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::conv_tbc + depends: + - name: aten::addmm_ + - name: aten::conv_tbc + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::select + - name: aten::to + - name: aten::view +- name: aten::conv_tbc_backward + depends: + - name: aten::addmm_ + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::select + - name: aten::sum + - name: aten::t + - name: aten::view + - name: aten::zeros_like +- name: aten::conv_transpose1d + depends: + - name: aten::conv_transpose1d + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::conv_transpose2d + depends: + - name: aten::conv_transpose2d + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::conv_transpose3d + depends: + - name: aten::conv_transpose3d + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::convolution + depends: + - name: aten::_convolution + - name: aten::convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::convolution_backward_overrideable + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::convolution_overrideable + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::copy_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::copy_imag + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::copy_real + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::copy_sparse_to_sparse_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cos + depends: + - name: aten::as_strided_ + - name: aten::cos + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::cos_ + depends: + - name: aten::cos + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cosh + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::cosh + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::cosh_ + depends: + - name: aten::cosh + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cosine_embedding_loss + depends: + - name: aten::add + - name: aten::clamp_min_ + - name: aten::cosine_embedding_loss + - name: aten::div + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mean + - name: aten::mul + - name: aten::sqrt_ + - name: aten::sub + - name: aten::sub_ + - name: aten::sum + - name: aten::to + - name: aten::where + - name: aten::zeros_like +- name: aten::cosine_similarity + depends: + - name: aten::clamp_min_ + - name: aten::cosine_similarity + - name: aten::div_ + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mul + - name: aten::sqrt_ + - name: aten::sum + - name: aten::to +- name: aten::count + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::count_nonzero + depends: + - name: aten::count_nonzero + - name: aten::eq + - name: aten::is_nonzero + - name: aten::ne + - name: aten::sum +- name: aten::cross + depends: + - name: aten::cross + - name: aten::empty_like + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::ctc_loss + depends: + - name: aten::_ctc_loss + - name: aten::_cudnn_ctc_loss + - name: aten::_use_cudnn_ctc_loss + - name: aten::clamp_min + - name: aten::contiguous + - name: aten::div + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mean + - name: aten::sum + - name: aten::to + - name: aten::where + - name: aten::zeros +- name: aten::cudnn_affine_grid_generator + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_affine_grid_generator_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_batch_norm + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_batch_norm_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_convolution + depends: + - name: aten::cudnn_convolution + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::cudnn_convolution_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_convolution_backward_input + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_convolution_backward_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_convolution_transpose + depends: + - name: aten::cudnn_convolution_transpose + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::cudnn_convolution_transpose_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_convolution_transpose_backward_input + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_convolution_transpose_backward_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_grid_sampler + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_grid_sampler_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cudnn_is_acceptable + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::cummax + depends: + - name: aten::_cummax_helper + - name: aten::cummax + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ +- name: aten::cummin + depends: + - name: aten::_cummin_helper + - name: aten::cummin + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ +- name: aten::cumprod + depends: + - name: aten::_cumprod + - name: aten::cumprod + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::to +- name: aten::cumsum + depends: + - name: aten::_cumsum + - name: aten::cumsum + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::to +- name: aten::data + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::deg2rad + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::deg2rad + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::mul + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::deg2rad_ + depends: + - name: aten::deg2rad + - name: aten::eq + - name: aten::is_nonzero +- name: aten::dense_dim + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::dequantize + depends: + - name: aten::dequantize + - name: aten::eq + - name: aten::is_nonzero +- name: aten::det + depends: + - name: aten::_lu_with_info + - name: aten::add_ + - name: aten::all + - name: aten::arange + - name: aten::contiguous + - name: aten::diagonal + - name: aten::eq + - name: aten::fmod_ + - name: aten::ge + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul_ + - name: aten::ne + - name: aten::prod + - name: aten::size + - name: aten::sum +- name: aten::detach + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::detach_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::diag + depends: + - name: aten::diag + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::zero_ +- name: aten::diag_embed + depends: + - name: aten::copy_ + - name: aten::diagonal + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::zeros +- name: aten::diagflat + depends: + - name: aten::contiguous + - name: aten::diag + - name: aten::eq + - name: aten::is_nonzero + - name: aten::view +- name: aten::diagonal + depends: + - name: aten::as_strided + - name: aten::diagonal + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::refine_names + - name: aten::size + - name: aten::stride +- name: aten::digamma + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::digamma + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::digamma_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::dist + depends: + - name: aten::dist + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::norm + - name: aten::sub + - name: aten::to +- name: aten::div + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::div + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::permute + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::view +- name: aten::div_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::div_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::dot + depends: + - name: aten::dot + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::stride + - name: aten::to +- name: aten::dropout + depends: + - name: aten::_fused_dropout + - name: aten::bernoulli_ + - name: aten::div_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::zeros +- name: aten::dropout_ + depends: + - name: aten::bernoulli_ + - name: aten::div_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::zeros +- name: aten::dstack + depends: + - name: aten::atleast_3d + - name: aten::cat + - name: aten::eq + - name: aten::is_nonzero +- name: aten::eig + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::einsum + depends: + - name: aten::bmm + - name: aten::diagonal + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::permute + - name: aten::reshape + - name: aten::size + - name: aten::sum + - name: aten::unsqueeze + - name: aten::view +- name: aten::elu + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::elu_ + depends: + - name: aten::elu + - name: aten::eq + - name: aten::is_nonzero +- name: aten::elu_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::embedding + depends: + - name: aten::eq + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::view +- name: aten::embedding_backward + depends: + - name: aten::embedding_dense_backward + - name: aten::embedding_sparse_backward + - name: aten::eq + - name: aten::is_nonzero +- name: aten::embedding_bag + depends: + - name: aten::_embedding_bag + - name: aten::_embedding_bag_forward_only + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero +- name: aten::embedding_dense_backward + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::select + - name: aten::size + - name: aten::view + - name: aten::zeros +- name: aten::embedding_renorm_ + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul_ + - name: aten::norm + - name: aten::select +- name: aten::embedding_sparse_backward + depends: + - name: aten::_sparse_coo_tensor_unsafe + - name: aten::empty + - name: aten::eq + - name: aten::index + - name: aten::is_nonzero + - name: aten::ne + - name: aten::reshape + - name: aten::size +- name: aten::empty + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::sparse_resize_and_clear_ +- name: aten::empty_like + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ +- name: aten::empty_meta + depends: + - name: aten::empty_meta + - name: aten::eq + - name: aten::is_nonzero +- name: aten::empty_quantized + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme +- name: aten::empty_strided + depends: + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero +- name: aten::endswith + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::eq + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::eq_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::equal + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::equal + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::is_same_size + - name: aten::resize_ + - name: aten::to +- name: aten::erf + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::erf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::erf_ + depends: + - name: aten::eq + - name: aten::erf + - name: aten::is_nonzero +- name: aten::erfc + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::erfc + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::erfc_ + depends: + - name: aten::eq + - name: aten::erfc + - name: aten::is_nonzero +- name: aten::erfinv + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::erfinv + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::erfinv_ + depends: + - name: aten::eq + - name: aten::erfinv + - name: aten::is_nonzero +- name: aten::exp + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::exp + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::exp_ + depends: + - name: aten::eq + - name: aten::exp + - name: aten::is_nonzero +- name: aten::expand + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::permute + - name: aten::view +- name: aten::expand_as + depends: + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero +- name: aten::expandtabs + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::expm1 + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expm1 + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::expm1_ + depends: + - name: aten::eq + - name: aten::expm1 + - name: aten::is_nonzero +- name: aten::exponential_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::eye + depends: + - name: aten::empty + - name: aten::eq + - name: aten::eye + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::zero_ +- name: aten::fake_quantize_per_channel_affine + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to + - name: aten::view +- name: aten::fake_quantize_per_channel_affine_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to + - name: aten::view +- name: aten::fake_quantize_per_tensor_affine + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::fake_quantize_per_tensor_affine_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::fbgemm_linear_fp16_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fbgemm_linear_fp16_weight_fp32_activation + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fbgemm_linear_int8_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fbgemm_linear_int8_weight_fp32_activation + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fbgemm_linear_quantize_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fbgemm_pack_gemm_matrix_fp16 + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fbgemm_pack_quantized_matrix + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::feature_alpha_dropout + depends: + - name: aten::add + - name: aten::add_ + - name: aten::bernoulli_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::mul_ + - name: aten::zeros +- name: aten::feature_alpha_dropout_ + depends: + - name: aten::add + - name: aten::add_ + - name: aten::bernoulli_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::zeros +- name: aten::feature_dropout + depends: + - name: aten::bernoulli_ + - name: aten::div_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::zeros +- name: aten::feature_dropout_ + depends: + - name: aten::bernoulli_ + - name: aten::div_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::zeros +- name: aten::fft + depends: + - name: aten::_fft_with_size + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::fft_fft + depends: + - name: aten::eq + - name: aten::fft + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::view_as_complex + - name: aten::view_as_real +- name: aten::fill_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::fill_diagonal_ + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::find + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::fix + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to + - name: aten::trunc +- name: aten::fix_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::trunc +- name: aten::flatten + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size +- name: aten::flip + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::fliplr + depends: + - name: aten::eq + - name: aten::flip + - name: aten::is_nonzero +- name: aten::flipud + depends: + - name: aten::eq + - name: aten::flip + - name: aten::is_nonzero +- name: aten::floor + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::floor + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to +- name: aten::floor_ + depends: + - name: aten::eq + - name: aten::floor + - name: aten::is_nonzero +- name: aten::floor_divide + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::floor_divide + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::trunc_ +- name: aten::floor_divide_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::floor_divide + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::trunc_ +- name: aten::fmod + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fmod + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::fmod_ + depends: + - name: aten::eq + - name: aten::fmod + - name: aten::is_nonzero +- name: aten::frac + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::frac + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::frac_ + depends: + - name: aten::eq + - name: aten::frac + - name: aten::is_nonzero +- name: aten::fractional_max_pool2d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::fractional_max_pool2d_backward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ +- name: aten::fractional_max_pool3d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::fractional_max_pool3d_backward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ +- name: aten::frobenius_norm + depends: + - name: aten::conj + - name: aten::eq + - name: aten::frobenius_norm + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mul + - name: aten::norm + - name: aten::real + - name: aten::sqrt + - name: aten::sum + - name: aten::to +- name: aten::from_file + depends: + - name: aten::eq + - name: aten::from_file + - name: aten::is_nonzero +- name: aten::full + depends: + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::full + - name: aten::is_nonzero + - name: aten::resize_ +- name: aten::full_like + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero +- name: aten::gather + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::gcd + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::gcd + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::gcd_ + depends: + - name: aten::eq + - name: aten::gcd + - name: aten::is_nonzero +- name: aten::ge + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::ge + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::ge_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::ge + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::gelu + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::as_strided_ + - name: aten::clone + - name: aten::copy_ + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::gelu + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::to +- name: aten::gelu_backward + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::as_strided_ + - name: aten::clone + - name: aten::copy_ + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::to +- name: aten::geometric_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::geqrf + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::ger + depends: + - name: aten::_addr + - name: aten::empty + - name: aten::eq + - name: aten::ger + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::get_gradients + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::glu + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::glu + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: aten::glu_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::glu_backward + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::narrow + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sigmoid + - name: aten::size + - name: aten::to +- name: aten::grid_sampler + depends: + - name: aten::cudnn_grid_sampler + - name: aten::eq + - name: aten::grid_sampler_2d + - name: aten::grid_sampler_3d + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::grid_sampler_2d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::grid_sampler_2d_backward + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::grid_sampler_3d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::grid_sampler_3d_backward + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::group_norm + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::group_norm + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::native_group_norm + - name: aten::size + - name: aten::to +- name: aten::gru + depends: + - name: aten::_thnn_fused_gru_cell + - name: aten::add + - name: aten::add_ + - name: aten::cat + - name: aten::cudnn_is_acceptable + - name: aten::dropout + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::mul_ + - name: aten::narrow + - name: aten::sigmoid_ + - name: aten::size + - name: aten::stack + - name: aten::sub + - name: aten::t + - name: aten::tanh_ + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unbind + - name: aten::unsafe_chunk +- name: aten::gru_cell + depends: + - name: aten::_thnn_fused_gru_cell + - name: aten::add + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::mul_ + - name: aten::sigmoid_ + - name: aten::sub + - name: aten::t + - name: aten::tanh_ + - name: aten::unsafe_chunk +- name: aten::gt + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::gt + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::gt_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::gt + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::hamming_window + depends: + - name: aten::add_ + - name: aten::arange + - name: aten::cos_ + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::hamming_window + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::narrow +- name: aten::hann_window + depends: + - name: aten::add_ + - name: aten::arange + - name: aten::cos_ + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::hann_window + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::narrow +- name: aten::hardshrink + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardshrink_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardsigmoid + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: aten::hardsigmoid_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardsigmoid_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardswish + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardswish_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardswish_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hardtanh + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::clamp + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: aten::hardtanh_ + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::clamp_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: aten::hardtanh_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hinge_embedding_loss + depends: + - name: aten::add + - name: aten::clamp_min_ + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::hinge_embedding_loss + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mean + - name: aten::ne + - name: aten::sub_ + - name: aten::sum + - name: aten::to + - name: aten::where + - name: aten::zeros_like +- name: aten::histc + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::zero_ +- name: aten::hspmm + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::hstack + depends: + - name: aten::atleast_1d + - name: aten::cat + - name: aten::eq + - name: aten::is_nonzero +- name: aten::hypot + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::hypot_ + depends: + - name: aten::eq + - name: aten::hypot + - name: aten::is_nonzero +- name: aten::ifft + depends: + - name: aten::_fft_with_size + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::im2col + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::im2col_backward + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::imag + depends: + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::select + - name: aten::view_as_real +- name: aten::index + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::nonzero + - name: aten::permute + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::index_add + depends: + - name: aten::clone + - name: aten::eq + - name: aten::index_add_ + - name: aten::is_nonzero +- name: aten::index_add_ + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::index_copy + depends: + - name: aten::clone + - name: aten::eq + - name: aten::index_copy_ + - name: aten::is_nonzero +- name: aten::index_copy_ + depends: + - name: aten::_index_copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: aten::index_fill + depends: + - name: aten::clone + - name: aten::eq + - name: aten::index_fill + - name: aten::index_fill_ + - name: aten::is_nonzero +- name: aten::index_fill_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::index_fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::index_put + depends: + - name: aten::clone + - name: aten::eq + - name: aten::index_put + - name: aten::index_put_ + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::to +- name: aten::index_put_ + depends: + - name: aten::_index_put_impl_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::index_select + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::indices + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::instance_norm + depends: + - name: aten::alias + - name: aten::batch_norm + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mean + - name: aten::repeat + - name: aten::size + - name: aten::view +- name: aten::int_repr + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::inverse + depends: + - name: aten::_inverse_helper + - name: aten::copy_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size +- name: aten::irfft + depends: + - name: aten::_fft_with_size + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::is_coalesced + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_complex + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_distributed + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_floating_point + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_leaf + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_nonzero + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item +- name: aten::is_pinned + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_same_size + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_set_to + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::is_signed + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::is_vulkan_available + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::isclose + depends: + - name: aten::__iand__ + - name: aten::__ior__ + - name: aten::abs + - name: aten::add + - name: aten::eq + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::isfinite + - name: aten::le + - name: aten::mul + - name: aten::ne + - name: aten::sub + - name: aten::to +- name: aten::isfinite + depends: + - name: aten::abs + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::isfinite + - name: aten::mul + - name: aten::ne + - name: aten::ones_like +- name: aten::isidentifier + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::isinf + depends: + - name: aten::__ior__ + - name: aten::abs + - name: aten::eq + - name: aten::imag + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::isinf + - name: aten::real + - name: aten::zeros_like +- name: aten::islower + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::isnan + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::ne +- name: aten::isneginf + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::isneginf + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::isposinf + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::isposinf + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::isprintable + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::isreal + depends: + - name: aten::eq + - name: aten::imag + - name: aten::is_nonzero + - name: aten::ones_like +- name: aten::istft + depends: + - name: aten::_fft_with_size + - name: aten::abs + - name: aten::col2im + - name: aten::constant_pad_nd + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::min + - name: aten::mul + - name: aten::ones + - name: aten::pow + - name: aten::repeat + - name: aten::reshape + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view +- name: aten::istitle + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::isupper + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::item + depends: + - name: aten::_local_scalar_dense + - name: aten::_nnz + - name: aten::_values + - name: aten::dequantize + - name: aten::eq + - name: aten::is_coalesced + - name: aten::is_nonzero + - name: aten::item + - name: aten::sum +- name: aten::join + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::kl_div + depends: + - name: aten::eq + - name: aten::exp + - name: aten::gt + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::kl_div + - name: aten::log + - name: aten::mean + - name: aten::mul + - name: aten::sub + - name: aten::sum + - name: aten::to + - name: aten::where + - name: aten::zeros_like +- name: aten::kl_div_backward + depends: + - name: aten::div + - name: aten::eq + - name: aten::exp + - name: aten::expand_as + - name: aten::is_nonzero + - name: aten::mul + - name: aten::neg + - name: aten::zeros_like +- name: aten::kthvalue + depends: + - name: aten::clone + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::kthvalue + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::squeeze_ + - name: aten::unsqueeze_ + - name: aten::zero_ +- name: aten::l1_loss + depends: + - name: aten::abs_ + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::l1_loss + - name: aten::mean + - name: aten::sub + - name: aten::sum + - name: aten::to +- name: aten::l1_loss_backward + depends: + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero + - name: aten::l1_loss_backward + - name: aten::mul_ + - name: aten::sign_ + - name: aten::sub + - name: aten::zeros_like +- name: aten::layer_norm + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::layer_norm + - name: aten::native_layer_norm + - name: aten::to +- name: aten::lcm + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::lcm + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::lcm_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::lcm +- name: aten::le + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::le + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::le_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::le + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::leaky_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::leaky_relu_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::leaky_relu + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::leaky_relu_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::lerp + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::lerp_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::lgamma + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::lgamma + - name: aten::resize_ + - name: aten::to +- name: aten::lgamma_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::lgamma +- name: aten::linalg_det + depends: + - name: aten::det + - name: aten::eq + - name: aten::is_nonzero +- name: aten::linear + depends: + - name: aten::add_ + - name: aten::addmm + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::mkldnn_linear + - name: aten::resize_ + - name: aten::size + - name: aten::t + - name: aten::to +- name: aten::linspace + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::linspace + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::ljust + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::log + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::log + - name: aten::resize_ + - name: aten::to +- name: aten::log10 + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::log10 + - name: aten::resize_ + - name: aten::to +- name: aten::log10_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::log10 +- name: aten::log1p + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::log1p + - name: aten::resize_ + - name: aten::to +- name: aten::log1p_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::log1p +- name: aten::log2 + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::log2 + - name: aten::resize_ + - name: aten::to +- name: aten::log2_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::log2 +- name: aten::log_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::log +- name: aten::log_normal_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::log_sigmoid + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::log_sigmoid_forward +- name: aten::log_sigmoid_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::log_sigmoid_forward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::zeros_like +- name: aten::log_softmax + depends: + - name: aten::_log_softmax + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::log_softmax + - name: aten::to +- name: aten::logaddexp + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logaddexp + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logaddexp2 + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logaddexp2 + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logcumsumexp + depends: + - name: aten::_logcumsumexp + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logcumsumexp + - name: aten::to +- name: aten::logdet + depends: + - name: aten::_lu_with_info + - name: aten::abs_ + - name: aten::add_ + - name: aten::all + - name: aten::arange + - name: aten::contiguous + - name: aten::diagonal + - name: aten::eq + - name: aten::fill_ + - name: aten::fmod_ + - name: aten::full + - name: aten::ge + - name: aten::index_put_ + - name: aten::is_nonzero + - name: aten::item + - name: aten::log_ + - name: aten::lt + - name: aten::mul_ + - name: aten::ne + - name: aten::nonzero_numpy + - name: aten::prod + - name: aten::sign + - name: aten::size + - name: aten::sum +- name: aten::logical_and + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::logical_and + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logical_and_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logical_and +- name: aten::logical_not + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logical_not + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logical_not_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logical_not +- name: aten::logical_or + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::logical_or + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logical_or_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logical_or +- name: aten::logical_xor + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::logical_xor + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logical_xor_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logical_xor +- name: aten::logit + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logit + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logit_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::logit +- name: aten::logit_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::logspace + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::logspace + - name: aten::resize_ +- name: aten::logsumexp + depends: + - name: aten::abs + - name: aten::add_ + - name: aten::empty + - name: aten::eq + - name: aten::exp + - name: aten::is_nonzero + - name: aten::log_ + - name: aten::logsumexp + - name: aten::masked_fill_ + - name: aten::max_values + - name: aten::squeeze + - name: aten::sub + - name: aten::sum +- name: aten::lstm + depends: + - name: aten::_thnn_fused_lstm_cell + - name: aten::add_ + - name: aten::cat + - name: aten::cudnn_is_acceptable + - name: aten::dropout + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::mul + - name: aten::narrow + - name: aten::sigmoid_ + - name: aten::size + - name: aten::stack + - name: aten::t + - name: aten::tanh + - name: aten::tanh_ + - name: aten::transpose + - name: aten::unbind + - name: aten::unsafe_chunk +- name: aten::lstm_cell + depends: + - name: aten::_thnn_fused_lstm_cell + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::mul + - name: aten::sigmoid_ + - name: aten::t + - name: aten::tanh + - name: aten::tanh_ + - name: aten::unsafe_chunk +- name: aten::lstrip + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::lstsq + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::unsqueeze +- name: aten::lt + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::lt + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::lt_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::lt + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::lu_solve + depends: + - name: aten::_lu_solve_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::lu_solve + - name: aten::resize_as_ + - name: aten::size +- name: aten::margin_ranking_loss + depends: + - name: aten::add + - name: aten::clamp_min_ + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::margin_ranking_loss + - name: aten::mean + - name: aten::mul + - name: aten::neg + - name: aten::sub + - name: aten::sum + - name: aten::to +- name: aten::masked_fill + depends: + - name: aten::clone + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::masked_fill_ +- name: aten::masked_fill_ + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::to +- name: aten::masked_scatter + depends: + - name: aten::clone + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::masked_scatter_ +- name: aten::masked_scatter_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::masked_select + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::item + - name: aten::resize_ + - name: aten::sum + - name: aten::to +- name: aten::matmul + depends: + - name: aten::_unsafe_view + - name: aten::bmm + - name: aten::contiguous + - name: aten::dot + - name: aten::eq + - name: aten::expand + - name: aten::fill_ + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::matmul + - name: aten::mm + - name: aten::mv + - name: aten::reshape + - name: aten::resize_ + - name: aten::set_ + - name: aten::size + - name: aten::squeeze_ + - name: aten::t + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view +- name: aten::matrix_exp + depends: + - name: aten::_unsafe_view + - name: aten::abs + - name: aten::add_ + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::bmm + - name: aten::ceil + - name: aten::contiguous + - name: aten::copy_ + - name: aten::diag_embed + - name: aten::div + - name: aten::div_ + - name: aten::dot + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::exp + - name: aten::expand + - name: aten::fill_ + - name: aten::ge + - name: aten::gt + - name: aten::index_put_ + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::item + - name: aten::le + - name: aten::log2 + - name: aten::matmul + - name: aten::max + - name: aten::mm + - name: aten::mul + - name: aten::mv + - name: aten::narrow + - name: aten::nonzero + - name: aten::ones + - name: aten::pow + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::set_ + - name: aten::size + - name: aten::slice + - name: aten::squeeze + - name: aten::squeeze_ + - name: aten::stride + - name: aten::sum + - name: aten::t + - name: aten::to + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view + - name: aten::zeros + - name: aten::zeros_like +- name: aten::matrix_exp_backward + depends: + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::matrix_exp + - name: aten::narrow + - name: aten::size + - name: aten::transpose + - name: aten::zeros +- name: aten::matrix_power + depends: + - name: aten::_unsafe_view + - name: aten::bmm + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::dot + - name: aten::eq + - name: aten::expand + - name: aten::expand_as + - name: aten::eye + - name: aten::fill_ + - name: aten::inverse + - name: aten::is_nonzero + - name: aten::mm + - name: aten::mv + - name: aten::reshape + - name: aten::resize_ + - name: aten::set_ + - name: aten::size + - name: aten::squeeze_ + - name: aten::t + - name: aten::transpose + - name: aten::unsqueeze + - name: aten::view +- name: aten::matrix_rank + depends: + - name: aten::abs + - name: aten::eq + - name: aten::gt + - name: aten::is_nonzero + - name: aten::max + - name: aten::mul_ + - name: aten::size + - name: aten::sum + - name: aten::svd + - name: aten::symeig +- name: aten::max + depends: + - name: aten::_make_per_tensor_quantized_tensor + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::int_repr + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::max + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::reshape + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::squeeze_ + - name: aten::stride + - name: aten::to + - name: aten::unsqueeze_ +- name: aten::max_pool1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::max_pool1d_with_indices +- name: aten::max_pool1d_with_indices + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::max_pool2d_with_indices + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::max_pool2d + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::max_pool2d_with_indices + - name: aten::mkldnn_max_pool2d + - name: aten::quantized_max_pool2d + - name: aten::resize_ + - name: aten::size +- name: aten::max_pool2d_with_indices + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::max_pool2d_with_indices_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::max_pool3d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::max_pool3d_with_indices + - name: aten::mkldnn_max_pool3d +- name: aten::max_pool3d_with_indices + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::max_pool3d_with_indices_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::max_unpool2d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ +- name: aten::max_unpool2d_backward + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ +- name: aten::max_unpool3d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ +- name: aten::max_unpool3d_backward + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ +- name: aten::max_values + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::max + - name: aten::resize_ + - name: aten::select + - name: aten::to + - name: aten::unsqueeze +- name: aten::mean + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::dequantize + - name: aten::div_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mean + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::size + - name: aten::sum + - name: aten::to +- name: aten::median + depends: + - name: aten::clone + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::kthvalue + - name: aten::median + - name: aten::select + - name: aten::size + - name: aten::view +- name: aten::meshgrid + depends: + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::size + - name: aten::view +- name: aten::min + depends: + - name: aten::_make_per_tensor_quantized_tensor + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::int_repr + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::min + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::reshape + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::squeeze_ + - name: aten::stride + - name: aten::to + - name: aten::unsqueeze_ +- name: aten::min_values + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::min + - name: aten::resize_ + - name: aten::select + - name: aten::to + - name: aten::unsqueeze +- name: aten::miopen_batch_norm + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_batch_norm_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_backward_bias + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_backward_input + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_backward_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_transpose + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_transpose_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_transpose_backward_input + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_convolution_transpose_backward_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_depthwise_convolution + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_depthwise_convolution_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_depthwise_convolution_backward_input + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_depthwise_convolution_backward_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_rnn + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::miopen_rnn_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_adaptive_avg_pool2d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_convolution + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_convolution_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_convolution_backward_input + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_convolution_backward_weights + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_linear + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_max_pool2d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_max_pool3d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_reorder_conv2d_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mkldnn_reorder_conv3d_weight + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::mm + depends: + - name: aten::clone + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::expand + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mm + - name: aten::to + - name: aten::transpose + - name: aten::transpose_ +- name: aten::mode + depends: + - name: aten::_mode + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mode + - name: aten::resize_ +- name: aten::movedim + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::movedim + - name: aten::permute +- name: aten::mse_loss + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mean + - name: aten::mse_loss + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sum + - name: aten::to +- name: aten::mse_loss_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mse_loss_backward + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to + - name: aten::zeros_like +- name: aten::mul + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::mul + - name: aten::permute + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::view +- name: aten::mul_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::multi_margin_loss + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::multi_margin_loss + - name: aten::resize_ + - name: aten::size + - name: aten::to +- name: aten::multi_margin_loss_backward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size +- name: aten::multilabel_margin_loss + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::multilabel_margin_loss + - name: aten::multilabel_margin_loss_forward + - name: aten::to +- name: aten::multilabel_margin_loss_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::multilabel_margin_loss_forward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ +- name: aten::multinomial + depends: + - name: aten::bitwise_and + - name: aten::div_ + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::ge + - name: aten::is_nonzero + - name: aten::item + - name: aten::log_ + - name: aten::lt + - name: aten::max + - name: aten::min + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::sum + - name: aten::topk + - name: aten::uniform_ +- name: aten::mv + depends: + - name: aten::_addmv_impl_ + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mv + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to + - name: aten::zero_ +- name: aten::mvlgamma + depends: + - name: aten::add + - name: aten::add_ + - name: aten::all + - name: aten::arange + - name: aten::empty + - name: aten::eq + - name: aten::gt + - name: aten::is_nonzero + - name: aten::item + - name: aten::lgamma_ + - name: aten::sum + - name: aten::unsqueeze +- name: aten::mvlgamma_ + depends: + - name: aten::add + - name: aten::add_ + - name: aten::all + - name: aten::arange + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::gt + - name: aten::is_nonzero + - name: aten::item + - name: aten::lgamma_ + - name: aten::sum + - name: aten::unsqueeze +- name: aten::nansum + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::sum + - name: aten::to + - name: aten::zero_ +- name: aten::narrow + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::narrow + - name: aten::size + - name: aten::slice +- name: aten::narrow_copy + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow +- name: aten::native_batch_norm + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::native_batch_norm_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::to +- name: aten::native_group_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ +- name: aten::native_group_norm_backward + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ +- name: aten::native_layer_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::native_layer_norm + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::to +- name: aten::native_layer_norm_backward + depends: + - name: aten::_empty_affine_quantized + - name: aten::_empty_per_channel_affine_quantized + - name: aten::clone + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_per_channel_axis + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::zero_ +- name: aten::native_norm + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::ne + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::ne + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::ne_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::ne + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::neg + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::neg + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::neg_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::neg +- name: aten::negative + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::neg + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::negative_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::neg +- name: aten::new_empty + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero +- name: aten::new_full + depends: + - name: aten::eq + - name: aten::full + - name: aten::is_nonzero +- name: aten::new_zeros + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::zeros +- name: aten::nextafter + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::nextafter_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nextafter +- name: aten::nll_loss + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::nll_loss + - name: aten::nll_loss_forward + - name: aten::to +- name: aten::nll_loss2d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::nll_loss2d + - name: aten::nll_loss2d_forward + - name: aten::to +- name: aten::nll_loss2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::nll_loss2d_forward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::nll_loss_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::nll_loss_forward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::nonzero + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::nonzero_numpy + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::nonzero + - name: aten::unbind + - name: aten::unsqueeze +- name: aten::norm + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::native_norm + - name: aten::norm + - name: aten::resize_ + - name: aten::to + - name: aten::zero_ +- name: aten::norm_except_dim + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::norm + - name: aten::norm_except_dim + - name: aten::size + - name: aten::transpose + - name: aten::view +- name: aten::normal + depends: + - name: aten::add_ + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::full + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::normal + - name: aten::normal_ + - name: aten::reshape + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to + - name: aten::view_as_real +- name: aten::normal_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to + - name: aten::view_as_real +- name: aten::nuclear_norm + depends: + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::nuclear_norm + - name: aten::permute + - name: aten::sum + - name: aten::svd + - name: aten::to + - name: aten::unsqueeze_ +- name: aten::numpy_T + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute +- name: aten::one_hot + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::scatter_ + - name: aten::unsqueeze + - name: aten::zeros +- name: aten::ones + depends: + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::ones + - name: aten::resize_ +- name: aten::ones_like + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero +- name: aten::orgqr + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::ormqr + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::outer + depends: + - name: aten::eq + - name: aten::ger + - name: aten::is_nonzero +- name: aten::output_nr + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::pairwise_distance + depends: + - name: aten::add + - name: aten::eq + - name: aten::is_nonzero + - name: aten::norm + - name: aten::sub +- name: aten::partition + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::pdist + depends: + - name: aten::_pdist_forward + - name: aten::contiguous + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::pdist + - name: aten::to +- name: aten::permute + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute +- name: aten::pin_memory + depends: + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::is_pinned + - name: aten::set_ +- name: aten::pinverse + depends: + - name: aten::diag_embed + - name: aten::empty + - name: aten::eq + - name: aten::gt + - name: aten::is_nonzero + - name: aten::matmul + - name: aten::mul + - name: aten::narrow + - name: aten::reciprocal + - name: aten::svd + - name: aten::transpose + - name: aten::where + - name: aten::zeros +- name: aten::pixel_shuffle + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::reshape + - name: aten::size +- name: aten::poisson + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::zeros +- name: aten::poisson_nll_loss + depends: + - name: aten::add + - name: aten::add_ + - name: aten::eq + - name: aten::exp + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::le + - name: aten::log + - name: aten::masked_fill + - name: aten::mean + - name: aten::mul + - name: aten::poisson_nll_loss + - name: aten::sub + - name: aten::sum + - name: aten::to +- name: aten::polar + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::polar + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::polygamma + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::polygamma + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::polygamma_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::polygamma +- name: aten::pow + depends: + - name: aten::as_strided_ + - name: aten::can_cast + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::pow + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::result_type + - name: aten::scalar_tensor + - name: aten::to +- name: aten::pow_ + depends: + - name: aten::as_strided_ + - name: aten::can_cast + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::result_type + - name: aten::to +- name: aten::prelu + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::prelu + - name: aten::size + - name: aten::to +- name: aten::prelu_backward + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::sum +- name: aten::prod + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::prod + - name: aten::resize_ + - name: aten::select + - name: aten::to + - name: aten::unsqueeze +- name: aten::promote_types + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::put_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::q_per_channel_axis + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::q_per_channel_scales + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::q_per_channel_zero_points + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::q_scale + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::q_zero_point + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::qr + depends: + - name: aten::_qr_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ +- name: aten::qscheme + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::quantile + depends: + - name: aten::all + - name: aten::ceil + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::flatten + - name: aten::floor + - name: aten::ge + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::item + - name: aten::le + - name: aten::lerp_ + - name: aten::logical_and_ + - name: aten::mul + - name: aten::permute + - name: aten::quantile + - name: aten::reshape + - name: aten::scalar_tensor + - name: aten::size + - name: aten::sort + - name: aten::sub + - name: aten::to + - name: aten::unsqueeze_ +- name: aten::quantize_per_channel + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::size + - name: aten::to +- name: aten::quantize_per_tensor + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::quantize_per_tensor + - name: aten::select +- name: aten::quantized_batch_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: aten::quantized_gru + depends: + - name: aten::_thnn_fused_gru_cell + - name: aten::add + - name: aten::add_ + - name: aten::cat + - name: aten::dropout + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul_ + - name: aten::narrow + - name: aten::sigmoid_ + - name: aten::size + - name: aten::stack + - name: aten::sub + - name: aten::tanh_ + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unbind + - name: aten::unsafe_chunk +- name: aten::quantized_gru_cell + depends: + - name: aten::_thnn_fused_gru_cell + - name: aten::add + - name: aten::add_ + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::sigmoid_ + - name: aten::sub + - name: aten::tanh_ + - name: aten::unsafe_chunk +- name: aten::quantized_lstm + depends: + - name: aten::_thnn_fused_lstm_cell + - name: aten::add_ + - name: aten::cat + - name: aten::dropout + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::narrow + - name: aten::sigmoid_ + - name: aten::size + - name: aten::stack + - name: aten::tanh + - name: aten::tanh_ + - name: aten::transpose + - name: aten::unbind + - name: aten::unsafe_chunk +- name: aten::quantized_lstm_cell + depends: + - name: aten::_thnn_fused_lstm_cell + - name: aten::add_ + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero + - name: aten::mul + - name: aten::sigmoid_ + - name: aten::tanh + - name: aten::tanh_ + - name: aten::unsafe_chunk +- name: aten::quantized_max_pool2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: aten::quantized_rnn_relu_cell + depends: + - name: aten::add_ + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero + - name: aten::relu +- name: aten::quantized_rnn_tanh_cell + depends: + - name: aten::add_ + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero + - name: aten::tanh +- name: aten::rad2deg + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::mul + - name: aten::rad2deg + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::rad2deg_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::rad2deg +- name: aten::rand + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::rand + - name: aten::resize_ + - name: aten::uniform_ +- name: aten::rand_like + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::uniform_ +- name: aten::randint + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::randint + - name: aten::random_ + - name: aten::resize_ +- name: aten::randint_like + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::random_ +- name: aten::randn + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::normal_ + - name: aten::randn + - name: aten::resize_ +- name: aten::randn_like + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::normal_ +- name: aten::random_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::randperm + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::randperm + - name: aten::resize_ + - name: aten::scalar_tensor + - name: aten::stride +- name: aten::range + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::range + - name: aten::resize_ +- name: aten::real + depends: + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::select + - name: aten::view_as_real +- name: aten::reciprocal + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::reciprocal + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::reciprocal_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reciprocal +- name: aten::refine_names + depends: + - name: aten::alias + - name: aten::eq + - name: aten::is_nonzero +- name: aten::reflection_pad1d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::size +- name: aten::reflection_pad1d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::reflection_pad2d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::reflection_pad2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::relu + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::threshold + - name: aten::to +- name: aten::relu_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::threshold_ + - name: aten::to +- name: aten::remainder + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::remainder_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::rename + depends: + - name: aten::alias + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rename_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::renorm + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::renorm + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::renorm_ + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::repeat + depends: + - name: aten::alias + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_quantized + - name: aten::eq + - name: aten::expand + - name: aten::expand_as + - name: aten::is_nonzero + - name: aten::size + - name: aten::unfold +- name: aten::repeat_interleave + depends: + - name: aten::all + - name: aten::contiguous + - name: aten::cumsum + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::expand + - name: aten::flatten + - name: aten::ge + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::item + - name: aten::repeat_interleave + - name: aten::reshape + - name: aten::select + - name: aten::size + - name: aten::to +- name: aten::replace + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::replication_pad1d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::replication_pad1d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::replication_pad2d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::replication_pad2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::replication_pad3d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::replication_pad3d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros_like +- name: aten::requires_grad + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::requires_grad_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::reshape + depends: + - name: aten::_mkldnn_reshape + - name: aten::_unsafe_view + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::reshape + - name: aten::view +- name: aten::reshape_as + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape +- name: aten::resize_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::resize_as_ + depends: + - name: aten::dense_dim + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::sparse_dim +- name: aten::result_type + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::result_type + - name: aten::scalar_tensor + - name: aten::to +- name: aten::retain_grad + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rfft + depends: + - name: aten::_fft_with_size + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: aten::rfind + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rindex + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rjust + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rnn_relu + depends: + - name: aten::add_ + - name: aten::cat + - name: aten::cudnn_is_acceptable + - name: aten::dropout + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::narrow + - name: aten::relu + - name: aten::size + - name: aten::stack + - name: aten::t + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unbind +- name: aten::rnn_relu_cell + depends: + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::relu + - name: aten::t +- name: aten::rnn_tanh + depends: + - name: aten::add_ + - name: aten::cat + - name: aten::cudnn_is_acceptable + - name: aten::dropout + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::narrow + - name: aten::size + - name: aten::stack + - name: aten::t + - name: aten::tanh + - name: aten::transpose + - name: aten::transpose_ + - name: aten::unbind +- name: aten::rnn_tanh_cell + depends: + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::linear + - name: aten::matmul + - name: aten::t + - name: aten::tanh +- name: aten::roll + depends: + - name: aten::cat + - name: aten::clone + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::roll + - name: aten::size + - name: aten::view +- name: aten::rot90 + depends: + - name: aten::clone + - name: aten::eq + - name: aten::flip + - name: aten::is_nonzero + - name: aten::transpose_ +- name: aten::round + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::round + - name: aten::to +- name: aten::round_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::round +- name: aten::rpartition + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rrelu + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::rrelu_with_noise +- name: aten::rrelu_ + depends: + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::rrelu_with_noise_ +- name: aten::rrelu_with_noise + depends: + - name: aten::add + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::div + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::leaky_relu + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::rrelu_with_noise_ + depends: + - name: aten::add + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::div + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::leaky_relu + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::rrelu_with_noise_backward + depends: + - name: aten::add + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::div + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::item + - name: aten::leaky_relu_backward + - name: aten::mul + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::sub + - name: aten::to +- name: aten::rsplit + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rsqrt + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::rsqrt + - name: aten::to +- name: aten::rsqrt_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::rsqrt +- name: aten::rstrip + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::rsub + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::permute + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::rsub + - name: aten::scalar_tensor + - name: aten::to + - name: aten::view +- name: aten::scalar_tensor + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::scatter + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::scatter_ +- name: aten::scatter_ + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::scatter_add + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::scatter_add_ +- name: aten::scatter_add_ + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::searchsorted + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::select + depends: + - name: aten::_indices + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::_values + - name: aten::arange + - name: aten::as_strided + - name: aten::dense_dim + - name: aten::empty + - name: aten::eq + - name: aten::index_select + - name: aten::is_nonzero + - name: aten::ne + - name: aten::nonzero + - name: aten::permute + - name: aten::select + - name: aten::size + - name: aten::sparse_dim + - name: aten::sum + - name: aten::view +- name: aten::selu + depends: + - name: aten::elu + - name: aten::eq + - name: aten::is_nonzero +- name: aten::selu_ + depends: + - name: aten::elu_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::set_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::set_ +- name: aten::set_data + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::set_quantizer_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sigmoid + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sigmoid + - name: aten::size + - name: aten::to +- name: aten::sigmoid_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sigmoid +- name: aten::sigmoid_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::sign + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sign + - name: aten::to +- name: aten::sign_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sign +- name: aten::signbit + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::signbit + - name: aten::to +- name: aten::silu + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::silu + - name: aten::to +- name: aten::silu_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::silu +- name: aten::silu_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::sin + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::sin + - name: aten::to +- name: aten::sin_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sin +- name: aten::sinh + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::sinh + - name: aten::to +- name: aten::sinh_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sinh +- name: aten::size + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::slice + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::slice +- name: aten::slogdet + depends: + - name: aten::_lu_with_info + - name: aten::abs_ + - name: aten::add_ + - name: aten::all + - name: aten::arange + - name: aten::contiguous + - name: aten::diagonal + - name: aten::eq + - name: aten::fmod_ + - name: aten::ge + - name: aten::is_nonzero + - name: aten::item + - name: aten::log_ + - name: aten::mul_ + - name: aten::ne + - name: aten::prod + - name: aten::sign + - name: aten::size + - name: aten::sum +- name: aten::slow_conv3d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::slow_conv3d_forward +- name: aten::slow_conv3d_backward + depends: + - name: aten::addmm_ + - name: aten::baddbmm_ + - name: aten::bmm + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mm + - name: aten::permute + - name: aten::reshape + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::transpose + - name: aten::view + - name: aten::zero_ +- name: aten::slow_conv3d_forward + depends: + - name: aten::addmm_ + - name: aten::baddbmm_ + - name: aten::bmm + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mm + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::unsqueeze + - name: aten::view +- name: aten::slow_conv_dilated2d + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::sum + - name: aten::unsqueeze + - name: aten::zero_ +- name: aten::slow_conv_dilated2d_backward + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::sum + - name: aten::unsqueeze + - name: aten::zero_ +- name: aten::slow_conv_dilated3d + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::sum + - name: aten::unsqueeze + - name: aten::zero_ +- name: aten::slow_conv_dilated3d_backward + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::sum + - name: aten::unsqueeze + - name: aten::zero_ +- name: aten::slow_conv_transpose2d + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::slow_conv_transpose2d_backward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::slow_conv_transpose3d + depends: + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::slow_conv_transpose3d_backward + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::zero_ +- name: aten::smm + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sspaddmm +- name: aten::smooth_l1_loss + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mean + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::smooth_l1_loss + - name: aten::sum + - name: aten::to +- name: aten::smooth_l1_loss_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::smooth_l1_loss_backward + - name: aten::to + - name: aten::zeros_like +- name: aten::soft_margin_loss + depends: + - name: aten::add_ + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::exp_ + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::log_ + - name: aten::mean + - name: aten::mul_ + - name: aten::neg + - name: aten::resize_ + - name: aten::soft_margin_loss + - name: aten::sum + - name: aten::to +- name: aten::soft_margin_loss_backward + depends: + - name: aten::add_ + - name: aten::div_ + - name: aten::empty + - name: aten::eq + - name: aten::exp + - name: aten::is_nonzero + - name: aten::mul + - name: aten::mul_ + - name: aten::neg + - name: aten::soft_margin_loss_backward +- name: aten::softmax + depends: + - name: aten::_softmax + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::softmax + - name: aten::to +- name: aten::softplus + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::softplus + - name: aten::to +- name: aten::softplus_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::softshrink + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::softshrink_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::solve + depends: + - name: aten::_solve_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size +- name: aten::sort + depends: + - name: aten::_copy_from + - name: aten::_make_per_tensor_quantized_tensor + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::int_repr + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::sort + - name: aten::stride + - name: aten::to +- name: aten::sparse_coo_tensor + depends: + - name: aten::_sparse_coo_tensor_with_dims + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::add_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::max + - name: aten::min + - name: aten::size + - name: aten::sparse_coo_tensor + - name: aten::to +- name: aten::sparse_dim + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sparse_mask + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sparse_resize_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sparse_resize_and_clear_ + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::split + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::permute + - name: aten::size + - name: aten::split +- name: aten::split_with_sizes + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::permute + - name: aten::size + - name: aten::split_with_sizes +- name: aten::splitlines + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sqrt + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::sqrt + - name: aten::to +- name: aten::sqrt_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sqrt +- name: aten::square + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::pow +- name: aten::square_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::pow +- name: aten::squeeze + depends: + - name: aten::as_strided + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::permute + - name: aten::size + - name: aten::squeeze + - name: aten::to +- name: aten::squeeze_ + depends: + - name: aten::as_strided_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sspaddmm + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sspaddmm +- name: aten::stack + depends: + - name: aten::cat + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::stack + - name: aten::to + - name: aten::unsqueeze +- name: aten::startswith + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::std + depends: + - name: aten::_std + - name: aten::add + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::imag + - name: aten::is_nonzero + - name: aten::real + - name: aten::resize_ + - name: aten::scalar_tensor + - name: aten::sqrt + - name: aten::std + - name: aten::to +- name: aten::std_mean + depends: + - name: aten::add + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::imag + - name: aten::is_nonzero + - name: aten::mul + - name: aten::real + - name: aten::resize_ + - name: aten::sqrt + - name: aten::std_mean + - name: aten::to +- name: aten::stft + depends: + - name: aten::as_strided + - name: aten::copy_ + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::mul + - name: aten::narrow + - name: aten::rfft + - name: aten::size + - name: aten::squeeze_ + - name: aten::stride + - name: aten::transpose_ + - name: aten::unsqueeze + - name: aten::zeros +- name: aten::stride + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::strip + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::sub + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::permute + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::sub + - name: aten::to + - name: aten::view +- name: aten::sub_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to +- name: aten::sum + depends: + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_floating_point + - name: aten::is_nonzero + - name: aten::permute + - name: aten::resize_ + - name: aten::select + - name: aten::sum + - name: aten::to + - name: aten::unsqueeze + - name: aten::zero_ +- name: aten::sum_to_size + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sum + - name: aten::view +- name: aten::svd + depends: + - name: aten::_svd_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ +- name: aten::symeig + depends: + - name: aten::_symeig_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size +- name: aten::t + depends: + - name: aten::dense_dim + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sparse_dim + - name: aten::transpose +- name: aten::t_ + depends: + - name: aten::dense_dim + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sparse_dim + - name: aten::transpose_ +- name: aten::take + depends: + - name: aten::_copy_from + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: aten::tan + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::tan + - name: aten::to +- name: aten::tan_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::tan +- name: aten::tanh + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::tanh + - name: aten::to +- name: aten::tanh_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::tanh +- name: aten::tanh_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::tensordot + depends: + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mm + - name: aten::permute + - name: aten::reshape + - name: aten::size + - name: aten::sum + - name: aten::tensordot + - name: aten::to +- name: aten::thnn_conv2d + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::thnn_conv2d_forward +- name: aten::thnn_conv2d_backward + depends: + - name: aten::addmm_ + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::reshape + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::select + - name: aten::size + - name: aten::transpose + - name: aten::view + - name: aten::zero_ +- name: aten::thnn_conv2d_forward + depends: + - name: aten::addmm_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mm + - name: aten::reshape + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::unsqueeze + - name: aten::view + - name: aten::zero_ +- name: aten::thnn_conv_depthwise2d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::thnn_conv_depthwise2d_forward +- name: aten::thnn_conv_depthwise2d_backward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::thnn_conv_depthwise2d_forward + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::threshold + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::threshold_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::threshold_backward + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: aten::title + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::to + depends: + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::to_dense + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::to_dense_backward + depends: + - name: aten::coalesce + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sparse_mask + - name: aten::to_mkldnn +- name: aten::to_mkldnn + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::to_mkldnn_backward + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to_dense +- name: aten::to_sparse + depends: + - name: aten::_coalesced_ + - name: aten::chunk + - name: aten::clone + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::index + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::nonzero + - name: aten::size + - name: aten::sparse_coo_tensor + - name: aten::squeeze + - name: aten::transpose + - name: aten::unique_dim + - name: aten::unsqueeze +- name: aten::topk + depends: + - name: aten::_empty_affine_quantized + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::topk + - name: aten::zero_ +- name: aten::trace + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::scalar_tensor +- name: aten::transpose + depends: + - name: aten::_coalesced_ + - name: aten::_indices + - name: aten::_mkldnn_transpose + - name: aten::_values + - name: aten::as_strided + - name: aten::clone + - name: aten::copy_ + - name: aten::dense_dim + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::select + - name: aten::size + - name: aten::sparse_dim + - name: aten::transpose + - name: aten::zeros_like +- name: aten::transpose_ + depends: + - name: aten::_coalesced_ + - name: aten::_indices + - name: aten::_mkldnn_transpose_ + - name: aten::_values + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::dense_dim + - name: aten::eq + - name: aten::is_nonzero + - name: aten::select + - name: aten::size + - name: aten::sparse_dim + - name: aten::zeros_like +- name: aten::trapz + depends: + - name: aten::add + - name: aten::div + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::select + - name: aten::size + - name: aten::slice + - name: aten::sub + - name: aten::sum + - name: aten::view + - name: aten::zeros +- name: aten::triangular_solve + depends: + - name: aten::_triangular_solve_helper + - name: aten::copy_ + - name: aten::eq + - name: aten::expand + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size +- name: aten::tril + depends: + - name: aten::as_strided + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::stride + - name: aten::tril +- name: aten::tril_ + depends: + - name: aten::as_strided + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::tril_indices + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::tril_indices +- name: aten::triplet_margin_loss + depends: + - name: aten::add + - name: aten::clamp_min + - name: aten::eq + - name: aten::is_floating_point + - name: aten::is_leaf + - name: aten::is_nonzero + - name: aten::mean + - name: aten::min + - name: aten::pairwise_distance + - name: aten::sub + - name: aten::sum + - name: aten::to + - name: aten::triplet_margin_loss +- name: aten::triu + depends: + - name: aten::as_strided + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_as_ + - name: aten::size + - name: aten::stride + - name: aten::triu +- name: aten::triu_ + depends: + - name: aten::as_strided + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::stride +- name: aten::triu_indices + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::triu_indices +- name: aten::true_divide + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::true_divide +- name: aten::true_divide_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::true_divide_ +- name: aten::trunc + depends: + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::to + - name: aten::trunc +- name: aten::trunc_ + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::trunc +- name: aten::type_as + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::to +- name: aten::unbind + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::select + - name: aten::size + - name: aten::unbind +- name: aten::unflatten + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::view +- name: aten::unfold + depends: + - name: aten::as_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::size + - name: aten::stride + - name: aten::unfold +- name: aten::unfold_backward + depends: + - name: aten::arange + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::squeeze + - name: aten::stride + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros +- name: aten::uniform_ + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to + - name: aten::view_as_real +- name: aten::unique_consecutive + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::stack + - name: aten::transpose + - name: aten::unbind + - name: aten::view + - name: aten::zeros +- name: aten::unique_dim + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::select + - name: aten::size + - name: aten::stack + - name: aten::transpose + - name: aten::unbind + - name: aten::view + - name: aten::zeros +- name: aten::unique_dim_consecutive + depends: + - name: aten::add_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::select + - name: aten::size + - name: aten::stack + - name: aten::transpose + - name: aten::unbind + - name: aten::view + - name: aten::zeros +- name: aten::unsafe_chunk + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::unsafe_split + - name: aten::unsafe_split_with_sizes +- name: aten::unsafe_split + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::size +- name: aten::unsafe_split_with_sizes + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::narrow + - name: aten::size +- name: aten::unsqueeze + depends: + - name: aten::_cat + - name: aten::_indices + - name: aten::_nnz + - name: aten::_sparse_coo_tensor_with_dims_and_tensors + - name: aten::_values + - name: aten::add_ + - name: aten::as_strided + - name: aten::cat + - name: aten::contiguous + - name: aten::dense_dim + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::expand + - name: aten::is_nonzero + - name: aten::max + - name: aten::min + - name: aten::narrow + - name: aten::permute + - name: aten::select + - name: aten::size + - name: aten::sparse_dim + - name: aten::to + - name: aten::unsqueeze + - name: aten::zero_ +- name: aten::unsqueeze_ + depends: + - name: aten::as_strided_ + - name: aten::eq + - name: aten::is_nonzero +- name: aten::upsample_bicubic2d + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ +- name: aten::upsample_bicubic2d_backward + depends: + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::upsample_bilinear2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size +- name: aten::upsample_bilinear2d_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::upsample_linear1d + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::upsample_linear1d_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::upsample_nearest1d + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::upsample_nearest1d_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::upsample_nearest2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size +- name: aten::upsample_nearest2d_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::upsample_nearest3d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::size +- name: aten::upsample_nearest3d_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::upsample_trilinear3d + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: aten::upsample_trilinear3d_backward + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size + - name: aten::zero_ + - name: aten::zeros +- name: aten::values + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: aten::vander + depends: + - name: aten::copy_ + - name: aten::cumprod + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::flip + - name: aten::is_nonzero + - name: aten::promote_types + - name: aten::select + - name: aten::size + - name: aten::slice + - name: aten::unsqueeze +- name: aten::var + depends: + - name: aten::_var + - name: aten::add + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::imag + - name: aten::is_nonzero + - name: aten::real + - name: aten::resize_ + - name: aten::scalar_tensor + - name: aten::sqrt + - name: aten::std + - name: aten::to + - name: aten::var +- name: aten::var_mean + depends: + - name: aten::add + - name: aten::as_strided + - name: aten::as_strided_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::fill_ + - name: aten::imag + - name: aten::is_nonzero + - name: aten::mul + - name: aten::real + - name: aten::resize_ + - name: aten::sqrt + - name: aten::to + - name: aten::var_mean +- name: aten::view + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::permute + - name: aten::view +- name: aten::view_as + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::view +- name: aten::view_as_complex + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::set_ +- name: aten::view_as_real + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::set_ +- name: aten::vstack + depends: + - name: aten::atleast_2d + - name: aten::cat + - name: aten::eq + - name: aten::is_nonzero +- name: aten::where + depends: + - name: aten::_s_where + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::expand + - name: aten::fill_ + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::nonzero_numpy + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::scalar_tensor + - name: aten::to + - name: aten::where +- name: aten::zero_ + depends: + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero +- name: aten::zeros + depends: + - name: aten::empty + - name: aten::eq + - name: aten::fill_ + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::sparse_resize_and_clear_ + - name: aten::zero_ + - name: aten::zeros +- name: aten::zeros_like + depends: + - name: aten::dense_dim + - name: aten::empty + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::sparse_dim + - name: aten::sparse_resize_and_clear_ + - name: aten::zero_ +- name: aten::zfill + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: prepacked::conv2d_clamp_prepack + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: prepacked::conv2d_clamp_run + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: prepacked::linear_clamp_prepack + depends: + - name: aten::contiguous + - name: aten::copy_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::resize_ + - name: aten::size +- name: prepacked::linear_clamp_run + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: profiler::_record_function_enter + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero +- name: profiler::_record_function_exit + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::add + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::add_out + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::add_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::to +- name: quantized::add_relu_out + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::add_scalar + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::add_scalar_out + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::add_scalar_relu + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::add_scalar_relu_out + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::batch_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: quantized::batch_norm1d + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: quantized::batch_norm1d_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: quantized::batch_norm2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: quantized::batch_norm2d_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: quantized::batch_norm3d + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: quantized::batch_norm3d_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: quantized::batch_norm_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::clone + - name: aten::contiguous + - name: aten::empty_like + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size + - name: aten::squeeze + - name: aten::unsqueeze +- name: quantized::cat + depends: + - name: aten::_empty_affine_quantized + - name: aten::cat + - name: aten::dequantize + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::size +- name: quantized::cat_out + depends: + - name: aten::_copy_from + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::cat + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: quantized::cat_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::cat + - name: aten::copy_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: quantized::cat_relu_out + depends: + - name: aten::_copy_from + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::cat + - name: aten::copy_ + - name: aten::copy_sparse_to_sparse_ + - name: aten::dequantize + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_complex + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::size + - name: aten::stride + - name: aten::to +- name: quantized::celu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::clamp + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: quantized::conv1d + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::squeeze_ + - name: aten::unsqueeze +- name: quantized::conv1d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::unsqueeze + - name: aten::zeros +- name: quantized::conv1d_relu + depends: + - name: aten::eq + - name: aten::is_nonzero + - name: aten::squeeze_ + - name: aten::unsqueeze +- name: quantized::conv1d_unpack + depends: + - name: aten::clone + - name: aten::eq + - name: aten::is_nonzero + - name: aten::squeeze_ +- name: quantized::conv2d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv2d_dilation + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv2d_groups + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv2d_padding + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv2d_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros +- name: quantized::conv2d_relu + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv2d_stride + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv2d_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_dilation + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_groups + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_padding + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_relu + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_stride + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv3d_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::conv_prepack + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros +- name: quantized::conv_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::elu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::embedding_bag_2bit_prepack + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_2bit_unpack + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_4bit_prepack + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_4bit_rowwise_offsets + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size + - name: aten::to +- name: quantized::embedding_bag_4bit_unpack + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_byte + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::embedding_bag_byte_prepack + depends: + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_byte_rowwise_offsets + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_byte_unpack + depends: + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::size +- name: quantized::embedding_bag_prepack + depends: + - name: aten::_empty_per_channel_affine_quantized + - name: aten::contiguous + - name: aten::empty + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::qscheme + - name: aten::select + - name: aten::set_ + - name: aten::size + - name: aten::to +- name: quantized::embedding_bag_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::group_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point +- name: quantized::hardswish + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::size + - name: aten::to +- name: quantized::instance_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point +- name: quantized::layer_norm + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point +- name: quantized::linear + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_dynamic + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_dynamic_fp16 + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_prepack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_prepack_fp16 + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_prepack_fp16_legacy + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_prepack_legacy + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::eq + - name: aten::equal + - name: aten::is_nonzero + - name: aten::item + - name: aten::max + - name: aten::min + - name: aten::mul + - name: aten::q_per_channel_scales + - name: aten::q_per_channel_zero_points + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::quantize_per_tensor + - name: aten::resize_ + - name: aten::select + - name: aten::size + - name: aten::to + - name: aten::zeros +- name: quantized::linear_relu + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_relu_dynamic + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_unpack + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::linear_unpack_fp16 + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::make_quantized_cell_params + depends: + - name: aten::eq + - name: aten::fbgemm_linear_int8_weight_fp32_activation + - name: aten::is_nonzero +- name: quantized::make_quantized_cell_params_dynamic + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::make_quantized_cell_params_fp16 + depends: + - name: aten::eq + - name: aten::is_nonzero +- name: quantized::max_pool2d + depends: + - name: aten::_empty_affine_quantized + - name: aten::contiguous + - name: aten::eq + - name: aten::is_nonzero + - name: aten::max_pool2d + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::size +- name: quantized::mul + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::mul_out + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::mul_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::mul_relu_out + depends: + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::mul_scalar + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::mul_scalar_out + depends: + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::mul_scalar_relu + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::mul_scalar_relu_out + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::contiguous + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::item + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::qscheme + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::set_quantizer_ + - name: aten::to +- name: quantized::quantized_gru_cell_dynamic + depends: + - name: aten::_thnn_fused_gru_cell + - name: aten::add + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul_ + - name: aten::sigmoid_ + - name: aten::sub + - name: aten::tanh_ + - name: aten::unsafe_chunk +- name: quantized::quantized_lstm_cell_dynamic + depends: + - name: aten::_thnn_fused_lstm_cell + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::mul + - name: aten::sigmoid_ + - name: aten::tanh + - name: aten::tanh_ + - name: aten::unsafe_chunk +- name: quantized::quantized_rnn_relu_cell_dynamic + depends: + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::relu +- name: quantized::quantized_rnn_tanh_cell_dynamic + depends: + - name: aten::add_ + - name: aten::eq + - name: aten::is_nonzero + - name: aten::tanh +- name: quantized::relu6 + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to +- name: quantized::threshold + depends: + - name: aten::_empty_affine_quantized + - name: aten::as_strided_ + - name: aten::copy_ + - name: aten::empty + - name: aten::empty_like + - name: aten::empty_strided + - name: aten::eq + - name: aten::is_nonzero + - name: aten::q_scale + - name: aten::q_zero_point + - name: aten::resize_ + - name: aten::resize_as_ + - name: aten::to diff --git a/tools/code_coverage/README.md b/tools/code_coverage/README.md new file mode 100644 index 000000000000000..dfa0a3fe8fd19fb --- /dev/null +++ b/tools/code_coverage/README.md @@ -0,0 +1,59 @@ +# Code Coverage Tool for Pytorch + +## Overview + +This tool is designed for calculating code coverage for Pytorch project in both fbcode and oss. But it also goes beyond Pytorch, applying to other folders in fbcode. + +It’s an integrated tool. You can use this tool to build, run, and generate both file-level and line-level report with both C++ tests and Python tests. + +### Simple +* *Simple command to run:* + * `python oss_coverage.py ` +* *Argument --clean will do all the messy clean up things for you* + +### But Powerful + +* *Choose your own interested folder*: + * Choose the folder you want to collect coverage for + * Flexible: default folder is good enough, but you can choose one or more other folders +* *Run only the test you want:* + * use --run-only to run the tests you want + * apply to both cpp and python tests +* *Final report:* + * File-Level: The coverage for each file you are interested in + * Line-Level: The coverage for each line in each file you are interested in +* *More complex but flexible options:* + * Use different stages like --build, --run, --summary to achieve more flexible functionality + +## How to use + +This part will introduce about the arguments you can use when run this tool. The arguments are powerful, giving you full flexibility to do different work. +If you are not familiar with the procedure of generating code coverage report by using clang, read [Source-based Code Coverage](https://clang.llvm.org/docs/SourceBasedCodeCoverage.html) will be helpful. + + +## Examples + +First step is to set some experimental value. +``` +# pytorch folder, by default all the c++ binaries are in build/bin/ +export PYTORCH_FOLDER=... +# make sure llvm-cov is available, by default it is /usr/local/opt/llvm/bin +export LLVM_TOOL_PATH=... +``` + +then command will run all the tests in `build/bin/` and `test/` folder +``` +python oss_coverage.py +``` +Most times you don't want collect coverage for the entire Pytorch folder, use --interested-folder to report coverage only over the folder you want: +``` +python oss_coverage.py --interested-folder=aten +``` +Then, still in most cases, if you only run one or several test(s): +``` +python oss_coverage.py --run-only=atest +python oss_coverage.py --run-only atest basic test_nn.py +``` + +### For more complex arguments and functionality +*To Be Done* diff --git a/tools/code_coverage/oss_coverage.py b/tools/code_coverage/oss_coverage.py new file mode 100644 index 000000000000000..f05df5a9f55d7ba --- /dev/null +++ b/tools/code_coverage/oss_coverage.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +import time + +from package.oss.cov_json import get_json_report +from package.oss.init import initialization +from package.tool.summarize_jsons import summarize_jsons + + +def report_coverage() -> None: + start_time = time.time() + (options, test_list, interested_folders) = initialization() + # run cpp tests + get_json_report(test_list, options) + # collect coverage data from json profiles + if options.need_summary: + summarize_jsons(test_list, interested_folders, [""], start_time) + + +if __name__ == "__main__": + report_coverage() diff --git a/tools/code_coverage/package/__init__.py b/tools/code_coverage/package/__init__.py new file mode 100644 index 000000000000000..e69de29bb2d1d64 diff --git a/tools/code_coverage/package/oss/__init__.py b/tools/code_coverage/package/oss/__init__.py new file mode 100644 index 000000000000000..e69de29bb2d1d64 diff --git a/tools/code_coverage/package/oss/cov_json.py b/tools/code_coverage/package/oss/cov_json.py new file mode 100644 index 000000000000000..4602a6f1d45d22c --- /dev/null +++ b/tools/code_coverage/package/oss/cov_json.py @@ -0,0 +1,35 @@ +import time + +from ..tool import clang_coverage, gcc_coverage +from ..util.setting import Option, TestList, TestPlatform +from ..util.utils import check_compiler_type, get_cov_type, print_time +from .init import gcc_export_init +from .run import clang_run, gcc_run + + +def get_json_report(test_list: TestList, options: Option): + start_time = time.time() + cov_type = get_cov_type() + # TODO change to enum + check_compiler_type(cov_type) + if cov_type == "CLANG": + # run + if options.need_run: + clang_run(test_list) + # merge && export + if options.need_merge: + clang_coverage.merge(test_list, TestPlatform.OSS) + if options.need_export: + clang_coverage.export(test_list, TestPlatform.OSS) + elif cov_type == "GCC": + # run + if options.need_run: + gcc_run(test_list) + # export + if options.need_export: + gcc_export_init() + gcc_coverage.export() + + print_time( + "collect coverage for cpp tests take time: ", start_time, summary_time=True + ) diff --git a/tools/code_coverage/package/oss/init.py b/tools/code_coverage/package/oss/init.py new file mode 100644 index 000000000000000..d45c3b47efe78aa --- /dev/null +++ b/tools/code_coverage/package/oss/init.py @@ -0,0 +1,136 @@ +import argparse +import os +from typing import List, Optional, Tuple + +from ..util.setting import ( + JSON_FOLDER_BASE_DIR, + LOG_DIR, + Option, + Test, + TestList, + TestType, +) +from ..util.utils import ( + clean_up, + create_folder, + get_cov_type, + print_log, + raise_no_test_found_exception, + remove_file, + remove_folder, +) +from ..util.utils_init import add_arguments_utils, create_folders, get_options +from .utils import ( + clean_up_gcda, + get_llvm_tool_path, + get_oss_binary_folder, + get_pytorch_folder, +) + + +def initialization() -> Tuple[Option, TestList, List[str]]: + # create folder if not exists + create_folders() + # add arguments + parser = argparse.ArgumentParser() + parser = add_arguments_utils(parser) + parser = add_arguments_oss(parser) + # parse arguments + (options, args_interested_folder, args_run_only, arg_clean) = parse_arguments( + parser + ) + # clean up + if arg_clean: + clean_up_gcda() + clean_up() + # get test lists + test_list = get_test_list(args_run_only) + # get interested folder -- final report will only over these folders + interested_folders = get_interested_folder(args_interested_folder) + # print initialization information + print_init_info() + # remove last time's log + remove_file(os.path.join(LOG_DIR, "log.txt")) + return (options, test_list, interested_folders) + + +def add_arguments_oss(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument( + "--run-only", + help="only run certain test(s), for example: atest test_nn.py.", + nargs="*", + default=None, + ) + + return parser + + +def parse_arguments( + parser: argparse.ArgumentParser, +) -> Tuple[Option, Optional[List[str]], Optional[List[str]], Optional[bool]]: + # parse args + args = parser.parse_args() + # get option + options = get_options(args) + return (options, args.interested_folder, args.run_only, args.clean) + + +def get_test_list_by_type( + run_only: Optional[List[str]], test_type: TestType +) -> TestList: + test_list: TestList = [] + binary_folder = get_oss_binary_folder(test_type) + g = os.walk(binary_folder) + for _, _, file_list in g: + for file_name in file_list: + if run_only is not None and file_name not in run_only: + continue + # target pattern in oss is used in printing report -- which tests we have run + test: Test = Test( + name=file_name, + target_pattern=file_name, + test_set="", + test_type=test_type, + ) + test_list.append(test) + return test_list + + +def get_test_list(run_only: Optional[List[str]]) -> TestList: + test_list: TestList = [] + # add c++ test list + test_list.extend(get_test_list_by_type(run_only, TestType.CPP)) + # add python test list + py_run_only = run_only if run_only else ["run_test.py"] + test_list.extend(get_test_list_by_type(py_run_only, TestType.PY)) + + # not find any test to run + if not test_list: + raise_no_test_found_exception( + get_oss_binary_folder(TestType.CPP), get_oss_binary_folder(TestType.PY) + ) + return test_list + + +def get_interested_folder(arg_interested_folder: Optional[List[str]]) -> List[str]: + if arg_interested_folder is not None: + # if this argument is specified, just return itself + return arg_interested_folder + else: + return [""] + + +def gcc_export_init(): + remove_folder(JSON_FOLDER_BASE_DIR) + create_folder(JSON_FOLDER_BASE_DIR) + + +def print_init_info() -> None: + print_log("pytorch folder: ", get_pytorch_folder()) + print_log("cpp test binaries folder: ", get_oss_binary_folder(TestType.CPP)) + print_log("python test scripts folder: ", get_oss_binary_folder(TestType.PY)) + print_log("cov_type: ", get_cov_type()) + print_log( + "llvm tool folder (only for clang, if you are using gcov please ignore it): ", + get_llvm_tool_path(), + ) diff --git a/tools/code_coverage/package/oss/run.py b/tools/code_coverage/package/oss/run.py new file mode 100644 index 000000000000000..d734c0249867b12 --- /dev/null +++ b/tools/code_coverage/package/oss/run.py @@ -0,0 +1,29 @@ +import os +import time + +from ..tool import clang_coverage, gcc_coverage +from ..util.setting import TestList, TestPlatform +from ..util.utils import get_raw_profiles_folder, print_time +from .utils import get_oss_binary_file + + +def clang_run(tests: TestList) -> None: + start_time = time.time() + for test in tests: + # raw_file + raw_file = os.path.join(get_raw_profiles_folder(), test.name + ".profraw") + # binary file + binary_file = get_oss_binary_file(test.name, test.test_type) + clang_coverage.run_target( + binary_file, raw_file, test.test_type, TestPlatform.OSS + ) + print_time("running binaries takes time: ", start_time, summary_time=True) + + +def gcc_run(tests: TestList) -> None: + start_time = time.time() + for test in tests: + # binary file + binary_file = get_oss_binary_file(test.name, test.test_type) + gcc_coverage.run_target(binary_file, test.test_type) + print_time("run binaries takes time: ", start_time, summary_time=True) diff --git a/tools/code_coverage/package/oss/utils.py b/tools/code_coverage/package/oss/utils.py new file mode 100644 index 000000000000000..13e7381f8ff9779 --- /dev/null +++ b/tools/code_coverage/package/oss/utils.py @@ -0,0 +1,71 @@ +import os +import subprocess +from typing import List + +from ..util.setting import SCRIPT_FOLDER, TestType +from ..util.utils import print_error, remove_file + + +def get_oss_binary_folder(test_type: TestType) -> str: + assert test_type in {TestType.CPP, TestType.PY} + # TODO: change the way we get binary file -- binary may not in build/bin ? + return os.path.join( + get_pytorch_folder(), "build/bin" if test_type == TestType.CPP else "test" + ) + + +def get_oss_shared_library() -> List[str]: + lib_dir = os.path.join(get_pytorch_folder(), "build", "lib") + return [ + os.path.join(lib_dir, lib) + for lib in os.listdir(lib_dir) + if lib.endswith(".dylib") + ] + + +def get_oss_binary_file(test_name: str, test_type: TestType) -> str: + assert test_type in {TestType.CPP, TestType.PY} + binary_folder = get_oss_binary_folder(test_type) + binary_file = os.path.join(binary_folder, test_name) + if test_type == TestType.PY: + # add python to the command so we can directly run the script by using binary_file variable + binary_file = "python " + binary_file + return binary_file + + +def get_llvm_tool_path() -> str: + return os.environ.get( + "LLVM_TOOL_PATH", "/usr/local/opt/llvm/bin" + ) # set default as llvm path in dev server, on mac the default may be /usr/local/opt/llvm/bin + + +def get_pytorch_folder() -> str: + return os.environ.get("PYTORCH_FOLDER", SCRIPT_FOLDER) + + +def clean_up_gcda() -> None: + gcda_files = get_gcda_files() + for item in gcda_files: + remove_file(item) + + +def get_gcda_files() -> List[str]: + folder_has_gcda = os.path.join(get_pytorch_folder(), "build") + if os.path.isdir(folder_has_gcda): + # TODO use glob + # output = glob.glob(f"{folder_has_gcda}/**/*.gcda") + output = subprocess.check_output(["find", folder_has_gcda, "-iname", "*.gcda"]) + output = output.decode("utf-8").split("\n") + return output + else: + return [] + + +def run_oss_python_test(binary_file: str) -> None: + # python test script + try: + subprocess.check_call( + binary_file, shell=True, cwd=get_oss_binary_folder(TestType.PY) + ) + except subprocess.CalledProcessError: + print_error(f"Binary failed to run: {binary_file}") diff --git a/tools/code_coverage/package/tool/__init__.py b/tools/code_coverage/package/tool/__init__.py new file mode 100644 index 000000000000000..e69de29bb2d1d64 diff --git a/tools/code_coverage/package/tool/clang_coverage.py b/tools/code_coverage/package/tool/clang_coverage.py new file mode 100644 index 000000000000000..56777e1d3527070 --- /dev/null +++ b/tools/code_coverage/package/tool/clang_coverage.py @@ -0,0 +1,175 @@ +import os +import subprocess +import time +from typing import List + +from ..util.setting import ( + JSON_FOLDER_BASE_DIR, + MERGED_FOLDER_BASE_DIR, + TestList, + TestPlatform, + TestType, +) +from ..util.utils import ( + check_platform_type, + convert_to_relative_path, + create_folder, + get_raw_profiles_folder, + get_test_name_from_whole_path, + print_log, + print_time, + related_to_test_list, + replace_extension, +) +from .utils import get_tool_path_by_platform, run_cpp_test + + +def create_corresponding_folder( + cur_path: str, prefix_cur_path: str, dir_list: List[str], new_base_folder: str +) -> None: + for dir_name in dir_list: + relative_path = convert_to_relative_path( + cur_path, prefix_cur_path + ) # get folder name like 'aten' + new_folder_path = os.path.join(new_base_folder, relative_path, dir_name) + create_folder(new_folder_path) + + +def run_target( + binary_file: str, raw_file: str, test_type: TestType, platform_type: TestPlatform +) -> None: + print_log("start run: ", binary_file) + # set environment variable -- raw profile output path of the binary run + os.environ["LLVM_PROFILE_FILE"] = raw_file + # run binary + if test_type == TestType.PY and platform_type == TestPlatform.OSS: + from ..oss.utils import run_oss_python_test + + run_oss_python_test(binary_file) + else: + run_cpp_test(binary_file) + + +def merge_target(raw_file: str, merged_file: str, platform_type: TestPlatform) -> None: + print_log("start to merge target: ", raw_file) + # run command + llvm_tool_path = get_tool_path_by_platform(platform_type) + subprocess.check_call( + [ + f"{llvm_tool_path}/llvm-profdata", + "merge", + "-sparse", + raw_file, + "-o", + merged_file, + ] + ) + + +def export_target( + merged_file: str, + json_file: str, + binary_file: str, + shared_library_list: List[str], + platform_type: TestPlatform, +) -> None: + if binary_file is None: + raise Exception(f"{merged_file} doesn't have corresponding binary!") + print_log("start to export: ", merged_file) + # run export + cmd_shared_library = ( + "" + if not shared_library_list + else f" -object {' -object '.join(shared_library_list)}" + ) + # if binary_file = "", then no need to add it (python test) + cmd_binary = "" if not binary_file else f" -object {binary_file} " + llvm_tool_path = get_tool_path_by_platform(platform_type) + + cmd = f"{llvm_tool_path}/llvm-cov export {cmd_binary} {cmd_shared_library} -instr-profile={merged_file} > {json_file}" + os.system(cmd) + + +def merge(test_list: TestList, platform_type: TestPlatform) -> None: + print("start merge") + start_time = time.time() + # find all raw profile under raw_folder and sub-folders + raw_folder_path = get_raw_profiles_folder() + g = os.walk(raw_folder_path) + for path, dir_list, file_list in g: + # if there is a folder raw/aten/, create corresponding merged folder profile/merged/aten/ if not exists yet + create_corresponding_folder( + path, raw_folder_path, dir_list, MERGED_FOLDER_BASE_DIR + ) + # check if we can find raw profile under this path's folder + for file_name in file_list: + if file_name.endswith(".profraw"): + if not related_to_test_list(file_name, test_list): + continue + print(f"start merge {file_name}") + raw_file = os.path.join(path, file_name) + merged_file_name = replace_extension(file_name, ".merged") + merged_file = os.path.join( + MERGED_FOLDER_BASE_DIR, + convert_to_relative_path(path, raw_folder_path), + merged_file_name, + ) + merge_target(raw_file, merged_file, platform_type) + print_time("merge take time: ", start_time, summary_time=True) + + +def export(test_list: TestList, platform_type: TestPlatform) -> None: + print("start export") + start_time = time.time() + # find all merged profile under merged_folder and sub-folders + g = os.walk(MERGED_FOLDER_BASE_DIR) + for path, dir_list, file_list in g: + # create corresponding merged folder in [json folder] if not exists yet + create_corresponding_folder( + path, MERGED_FOLDER_BASE_DIR, dir_list, JSON_FOLDER_BASE_DIR + ) + # check if we can find merged profile under this path's folder + for file_name in file_list: + if file_name.endswith(".merged"): + if not related_to_test_list(file_name, test_list): + continue + print(f"start export {file_name}") + # merged file + merged_file = os.path.join(path, file_name) + # json file + json_file_name = replace_extension(file_name, ".json") + json_file = os.path.join( + JSON_FOLDER_BASE_DIR, + convert_to_relative_path(path, MERGED_FOLDER_BASE_DIR), + json_file_name, + ) + check_platform_type(platform_type) + # binary file and shared library + binary_file = "" + shared_library_list = [] + if platform_type == TestPlatform.FBCODE: + from ..fbcode.utils import get_fbcode_binary_folder + + binary_file = os.path.join( + get_fbcode_binary_folder(path), + get_test_name_from_whole_path(merged_file), + ) + elif platform_type == TestPlatform.OSS: + from ..oss.utils import get_oss_binary_file, get_oss_shared_library + + test_name = get_test_name_from_whole_path(merged_file) + # if it is python test, no need to provide binary, shared library is enough + binary_file = ( + "" + if test_name.endswith(".py") + else get_oss_binary_file(test_name, TestType.CPP) + ) + shared_library_list = get_oss_shared_library() + export_target( + merged_file, + json_file, + binary_file, + shared_library_list, + platform_type, + ) + print_time("export take time: ", start_time, summary_time=True) diff --git a/tools/code_coverage/package/tool/gcc_coverage.py b/tools/code_coverage/package/tool/gcc_coverage.py new file mode 100644 index 000000000000000..d8de71aab6c1b4d --- /dev/null +++ b/tools/code_coverage/package/tool/gcc_coverage.py @@ -0,0 +1,49 @@ +import os +import subprocess +import time +from typing import Dict + +# gcc is only used in oss +from ..oss.utils import get_gcda_files, run_oss_python_test +from ..util.setting import JSON_FOLDER_BASE_DIR, TestType +from ..util.utils import print_log, print_time +from .utils import run_cpp_test + + +def update_gzip_dict(gzip_dict: Dict[str, int], file_name: str) -> str: + file_name = file_name.lower() + gzip_dict[file_name] = gzip_dict.get(file_name, 0) + 1 + num = gzip_dict[file_name] + return str(num) + "_" + file_name + + +def run_target(binary_file: str, test_type: TestType) -> None: + print_log("start run", test_type.value, "test: ", binary_file) + start_time = time.time() + assert test_type in {TestType.CPP, TestType.PY} + if test_type == TestType.CPP: + run_cpp_test(binary_file) + else: + run_oss_python_test(binary_file) + + print_time(" time: ", start_time) + + +def export() -> None: + start_time = time.time() + # collect .gcda files + gcda_files = get_gcda_files() + # file name like utils.cpp may have same name in different folder + gzip_dict: Dict[str, int] = {} + for gcda_item in gcda_files: + # generate json.gz + subprocess.check_call(["gcov", "-i", gcda_item]) + # cp json.gz to profile/json folder + gz_file_name = os.path.basename(gcda_item) + ".gcov.json.gz" + new_file_path = os.path.join( + JSON_FOLDER_BASE_DIR, update_gzip_dict(gzip_dict, gz_file_name) + ) + os.rename(gz_file_name, new_file_path) + # unzip json.gz to json + subprocess.check_output(["gzip", "-d", new_file_path]) + print_time("export take time: ", start_time, summary_time=True) diff --git a/tools/code_coverage/package/tool/parser/__init__.py b/tools/code_coverage/package/tool/parser/__init__.py new file mode 100644 index 000000000000000..e69de29bb2d1d64 diff --git a/tools/code_coverage/package/tool/parser/coverage_record.py b/tools/code_coverage/package/tool/parser/coverage_record.py new file mode 100644 index 000000000000000..1d6698aa861c413 --- /dev/null +++ b/tools/code_coverage/package/tool/parser/coverage_record.py @@ -0,0 +1,14 @@ +import typing as t + + +class CoverageRecord(t.NamedTuple): + filepath: str + covered_lines: t.List[int] + uncovered_lines: t.Optional[t.List[int]] = None + + def to_dict(self) -> t.Dict[str, t.Any]: + return { + "filepath": self.filepath, + "covered_lines": self.covered_lines, + "uncovered_lines": self.uncovered_lines, + } diff --git a/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py b/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py new file mode 100644 index 000000000000000..4d06cce1f953793 --- /dev/null +++ b/tools/code_coverage/package/tool/parser/gcov_coverage_parser.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, List, Set + +from .coverage_record import CoverageRecord + + +class GcovCoverageParser: + """ + Accepts a parsed json produced by gcov --json-format -- typically, + representing a single C++ test and produces a list + of CoverageRecord(s). + """ + + def __init__(self, llvm_coverage: Dict[str, Any]) -> None: + self._llvm_coverage = llvm_coverage + + @staticmethod + def _skip_coverage(path: str) -> bool: + """ + Returns True if file path should not be processed. + This is repo-specific and only makes sense for the current state of + ovrsource. + """ + if "third-party" in path: + return True + return False + + def parse(self) -> List[CoverageRecord]: + # The JSON format is described in the gcov source code + # https://gcc.gnu.org/onlinedocs/gcc/Invoking-Gcov.html + records: List[CoverageRecord] = [] + for file_info in self._llvm_coverage["files"]: + filepath = file_info["file"] + if self._skip_coverage(filepath): + continue + # parse json file + covered_lines: Set[int] = set() + uncovered_lines: Set[int] = set() + for line in file_info["lines"]: + line_number = line["line_number"] + count = line["count"] + if count == 0: + uncovered_lines.update([line_number]) + else: + covered_lines.update([line_number]) + + records.append( + CoverageRecord(filepath, sorted(covered_lines), sorted(uncovered_lines)) + ) + + return records diff --git a/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py b/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py new file mode 100644 index 000000000000000..8d0cc21d1f550b4 --- /dev/null +++ b/tools/code_coverage/package/tool/parser/llvm_coverage_parser.py @@ -0,0 +1,69 @@ +from typing import Any, Dict, List, Set, Tuple + +from .coverage_record import CoverageRecord +from .llvm_coverage_segment import LlvmCoverageSegment, parse_segments + + +class LlvmCoverageParser: + """ + Accepts a parsed json produced by llvm-cov export -- typically, + representing a single C++ test and produces a list + of CoverageRecord(s). + + """ + + def __init__(self, llvm_coverage: Dict[str, Any]) -> None: + self._llvm_coverage = llvm_coverage + + @staticmethod + def _skip_coverage(path: str) -> bool: + """ + Returns True if file path should not be processed. + This is repo-specific and only makes sense for the current state of + ovrsource. + """ + if "/third-party/" in path: + return True + return False + + @staticmethod + def _collect_coverage( + segments: List[LlvmCoverageSegment], + ) -> Tuple[List[int], List[int]]: + """ + Stateful parsing of coverage segments. + """ + covered_lines: Set[int] = set() + uncovered_lines: Set[int] = set() + prev_segment = LlvmCoverageSegment(1, 0, 0, 0, 0, None) + for segment in segments: + covered_range, uncovered_range = segment.get_coverage(prev_segment) + covered_lines.update(covered_range) + uncovered_lines.update(uncovered_range) + prev_segment = segment + + uncovered_lines.difference_update(covered_lines) + return sorted(covered_lines), sorted(uncovered_lines) + + def parse(self, repo_name: str) -> List[CoverageRecord]: + # The JSON format is described in the LLVM source code + # https://github.com/llvm-mirror/llvm/blob/master/tools/llvm-cov/CoverageExporterJson.cpp + records: List[CoverageRecord] = [] + for export_unit in self._llvm_coverage["data"]: + for file_info in export_unit["files"]: + filepath = file_info["filename"] + if self._skip_coverage(filepath): + continue + + if filepath is None: + continue + + segments = file_info["segments"] + + covered_lines, uncovered_lines = self._collect_coverage( + parse_segments(segments) + ) + + records.append(CoverageRecord(filepath, covered_lines, uncovered_lines)) + + return records diff --git a/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py new file mode 100644 index 000000000000000..7980b73fbe498b6 --- /dev/null +++ b/tools/code_coverage/package/tool/parser/llvm_coverage_segment.py @@ -0,0 +1,60 @@ +from typing import List, NamedTuple, Optional, Tuple + + +class LlvmCoverageSegment(NamedTuple): + line: int + col: int + segment_count: int + has_count: int + is_region_entry: int + is_gap_entry: Optional[int] + + @property + def has_coverage(self): + return self.segment_count > 0 + + @property + def is_executable(self): + return self.has_count > 0 + + def get_coverage( + self, prev_segment: "LlvmCoverageSegment" + ) -> Tuple[List[int], List[int]]: + # Code adapted from testpilot.testinfra.runners.gtestcoveragerunner.py + if not prev_segment.is_executable: + return [], [] + + # this segment ends at the line if col == 1 + # (so segment effectively ends on the line) and + # line+1 if col is > 1 (so it touches at least some part of last line). + end_of_segment = self.line if self.col == 1 else self.line + 1 + lines_range = list(range(prev_segment.line, end_of_segment)) + return (lines_range, []) if prev_segment.has_coverage else ([], lines_range) + + +def parse_segments(raw_segments: List[List[int]]) -> List[LlvmCoverageSegment]: + """ + Creates LlvmCoverageSegment from a list of lists in llvm export json. + each segment is represented by 5-element array. + """ + ret: List[LlvmCoverageSegment] = [] + for raw_segment in raw_segments: + assert ( + len(raw_segment) == 5 or len(raw_segment) == 6 + ), "list is not compatible with llvmcom export:" + " Expected to have 5 or 6 elements" + if len(raw_segment) == 5: + ret.append( + LlvmCoverageSegment( + raw_segment[0], + raw_segment[1], + raw_segment[2], + raw_segment[3], + raw_segment[4], + None, + ) + ) + else: + ret.append(LlvmCoverageSegment(*raw_segment)) + + return ret diff --git a/tools/code_coverage/package/tool/print_report.py b/tools/code_coverage/package/tool/print_report.py new file mode 100644 index 000000000000000..edfeee8086508c2 --- /dev/null +++ b/tools/code_coverage/package/tool/print_report.py @@ -0,0 +1,213 @@ +import os +import time +from typing import IO, Dict, List, Set + +from ..util.setting import SUMMARY_FOLDER_DIR, TestList, TestStatusType +from ..util.utils import convert_time + + +def key_by_percentage(x): + return x[1] + + +def key_by_name(x): + return x[0] + + +def is_intrested_file(file_path: str, interested_folders: List[str]): + if "cuda" in file_path: + return False + if "aten/gen_aten" in file_path or "aten/aten_" in file_path: + return False + for folder in interested_folders: + if folder in file_path: + return True + return False + + +def is_this_type_of_tests(target_name: str, test_set_by_type: Set[str]) -> bool: + # tests are divided into three types: success / partial success / fail to collect coverage + for test in test_set_by_type: + if target_name in test: + return True + return False + + +def print_test_by_type( + tests: TestList, test_set_by_type: Set[str], type_name: str, summary_file: IO +) -> None: + + print("Tests " + type_name + " to collect coverage:", file=summary_file) + for test in tests: + if is_this_type_of_tests(test.name, test_set_by_type): + print(test.target_pattern, file=summary_file) + print(file=summary_file) + + +def print_test_condition( + tests: TestList, + tests_type: TestStatusType, + interested_folders: List[str], + coverage_only: List[str], + summary_file: IO, + summary_type: str, +) -> None: + print_test_by_type(tests, tests_type["success"], "fully success", summary_file) + print_test_by_type(tests, tests_type["partial"], "partially success", summary_file) + print_test_by_type(tests, tests_type["fail"], "failed", summary_file) + print( + "\n\nCoverage Collected Over Interested Folders:\n", + interested_folders, + file=summary_file, + ) + print( + "\n\nCoverage Compilation Flags Only Apply To: \n", + coverage_only, + file=summary_file, + ) + print( + "\n\n---------------------------------- " + + summary_type + + " ----------------------------------", + file=summary_file, + ) + + +def line_oriented_report( + tests: TestList, + tests_type: TestStatusType, + interested_folders: List[str], + coverage_only: List[str], + covered_lines: Dict[str, Set[int]], + uncovered_lines: Dict[str, Set[int]], +) -> None: + with open(os.path.join(SUMMARY_FOLDER_DIR, "line_summary"), "w+") as report_file: + print_test_condition( + tests, + tests_type, + interested_folders, + coverage_only, + report_file, + "LINE SUMMARY", + ) + for file_name in covered_lines: + if len(covered_lines[file_name]) == 0: + covered = {} + else: + covered = covered_lines[file_name] + if len(uncovered_lines[file_name]) == 0: + uncovered = {} + else: + uncovered = uncovered_lines[file_name] + print( + f"{file_name}\n covered lines: {sorted(covered)}\n unconvered lines:{sorted(uncovered)}", + file=report_file, + ) + + +def print_total_program_time(start_time: float, summary_file: IO) -> None: + end_time = time.time() + # print to summary file + print( + f"PROGRAM RUNNING TIME: {convert_time(end_time - start_time)}\n\n", + file=summary_file, + ) + # print to terminal + print(f"time: {convert_time(end_time - start_time)}") + + +def print_file_summary( + covered_summary: int, total_summary: int, summary_file: IO +) -> None: + # print summary first + try: + coverage_percentage = round(1.0 * covered_summary / total_summary * 100, 2) + except ZeroDivisionError: + raise ZeroDivisionError( + "Failed to generate coverage report, please check if json profiles are valid in profile/json" + ) + print( + f"SUMMARY\ncovered: {covered_summary}\nuncovered: {total_summary}\npercentage: {coverage_percentage}%\n\n", + file=summary_file, + ) + + +def print_file_oriented_report( + tests_type: TestStatusType, + coverage, + covered_summary: int, + total_summary: int, + summary_file: IO, + tests: TestList, + interested_folders: List[str], + coverage_only: List[str], + program_start_time: float, +) -> None: + print_file_summary(covered_summary, total_summary, summary_file) + print_total_program_time(program_start_time, summary_file) + # print test condition (interested folder / tests that are successsful or failed) + print_test_condition( + tests, + tests_type, + interested_folders, + coverage_only, + summary_file, + "FILE SUMMARY", + ) + # print each file's information + for item in coverage: + print( + item[0].ljust(75), + (str(item[1]) + "%").rjust(10), + str(item[2]).rjust(10), + str(item[3]).rjust(10), + file=summary_file, + ) + + print( + f"summary percentage:{round(1.0 * covered_summary / total_summary * 100, 2)}%" + ) + + +def file_oriented_report( + tests: TestList, + tests_type: TestStatusType, + interested_folders: List[str], + coverage_only: List[str], + program_start_time: float, + covered_lines: Dict[str, Set[int]], + uncovered_lines: Dict[str, Set[int]], +) -> None: + with open(os.path.join(SUMMARY_FOLDER_DIR, "file_summary"), "w+") as summary_file: + start_time = time.time() + covered_summary = 0 + total_summary = 0 + coverage = [] + for file_name in covered_lines: + # get coverage number for this file + covered_count = len(covered_lines[file_name]) + total_count = covered_count + len(uncovered_lines[file_name]) + try: + percentage = round(covered_count / total_count * 100, 2) + except ZeroDivisionError: + percentage = 0 + # store information in a list to be sorted + coverage.append([file_name, percentage, covered_count, total_count]) + # update summary + covered_summary = covered_summary + covered_count + total_summary = total_summary + total_count + # sort + coverage.sort(key=key_by_name) + coverage.sort(key=key_by_percentage) + # print + print_file_oriented_report( + tests_type, + coverage, + covered_summary, + total_summary, + summary_file, + tests, + interested_folders, + coverage_only, + program_start_time, + ) diff --git a/tools/code_coverage/package/tool/summarize_jsons.py b/tools/code_coverage/package/tool/summarize_jsons.py new file mode 100644 index 000000000000000..2b023814343071b --- /dev/null +++ b/tools/code_coverage/package/tool/summarize_jsons.py @@ -0,0 +1,174 @@ +import json +import os +import time +from typing import Any, Dict, List, Optional, Set, Tuple + +from ..util.setting import JSON_FOLDER_BASE_DIR, TestList, TestStatusType +from ..util.utils import ( + check_compiler_type, + get_cov_type, + print_error, + print_time, + related_to_test_list, +) +from .parser.coverage_record import CoverageRecord +from .parser.gcov_coverage_parser import GcovCoverageParser +from .parser.llvm_coverage_parser import LlvmCoverageParser +from .print_report import file_oriented_report, line_oriented_report + + +# coverage_records: Dict[str, LineInfo] = dict() +covered_lines: Dict[str, Set[int]] = {} +uncovered_lines: Dict[str, Set[int]] = {} +tests_type: TestStatusType = {"success": set(), "partial": set(), "fail": set()} + + +def transform_file_name(file_path: str, interested_folders: List[str]) -> Optional[str]: + remove_patterns: Set[str] = {".DEFAULT.cpp", ".AVX.cpp", ".AVX2.cpp"} + for pattern in remove_patterns: + file_path = file_path.replace(pattern, "") + # if have interested folder + for folder in interested_folders: + if folder in file_path: + return file_path[file_path.find(folder) :] + return "" + + +def is_intrested_file(file_path: str, interested_folders: List[str]): + ignored_patterns = ["cuda", "aten/gen_aten", "aten/aten_", "build/"] + if any([pattern in file_path for pattern in ignored_patterns]): + return False + + for folder in interested_folders: + i_folder = folder if folder.endswith("/") else f"{folder}/" + if i_folder in file_path: + return True + return False + + +def get_json_obj(json_file: str) -> Tuple[Any, int]: + """ + Sometimes at the start of file llvm/gcov will complains "fail to find coverage data", + then we need to skip these lines + -- success read: 0 - this json file have the full json coverage information + -- partial success: 1 - this json file starts with some error prompt, but still have the coverage information + -- fail to read: 2 - this json file doesn't have any coverage information + """ + read_status = -1 + with open(json_file) as f: + lines = f.readlines() + for line in lines: + try: + json_obj = json.loads(line) + except json.JSONDecodeError: + read_status = 1 + continue + else: + if read_status == -1: + # not meet jsonDecoderError before, return success + read_status = 0 + return (json_obj, read_status) + return None, 2 + + +def parse_json(json_file: str) -> List[CoverageRecord]: + print("start parse:", json_file) + json_obj, read_status = get_json_obj(json_file) + if read_status == 0: + tests_type["success"].add(json_file) + elif read_status == 1: + tests_type["partial"].add(json_file) + else: + tests_type["fail"].add(json_file) + raise RuntimeError( + "Fail to do code coverage! Fail to load json file: ", json_file + ) + cov_type = get_cov_type() + check_compiler_type(cov_type) + if cov_type == "CLANG": + coverage_records = LlvmCoverageParser(json_obj).parse("fbcode") + # print(coverage_records) + elif cov_type == "GCC": + coverage_records = GcovCoverageParser(json_obj).parse() + + return coverage_records + + +def parse_jsons(test_list: TestList, interested_folders: List[str]) -> None: + g = os.walk(JSON_FOLDER_BASE_DIR) + + for path, _, file_list in g: + for file_name in file_list: + if file_name.endswith(".json"): + if not related_to_test_list(file_name, test_list): + continue + json_file = os.path.join(path, file_name) + try: + coverage_records = parse_json(json_file) + except RuntimeError: + print_error("Fail to load json file: ", json_file) + continue + # collect information from each target's export file and merge them together: + update_coverage(coverage_records, interested_folders) + + +def update_coverage( + coverage_records: List[CoverageRecord], interested_folders: List[str] +) -> None: + for item in coverage_records: + # extract information for the record + record = item.to_dict() + file_path = record["filepath"] + if not is_intrested_file(file_path, interested_folders): + continue + covered_range = record["covered_lines"] + uncovered_range = record["uncovered_lines"] + # transform file name: remote/13223/caffe2/aten -> caffe2/aten + file_path = transform_file_name(file_path, interested_folders) + if file_path is None: + continue + # if file not exists, add it into dictionary + if file_path not in covered_lines: + covered_lines[file_path] = set() + if file_path not in uncovered_lines: + uncovered_lines[file_path] = set() + # update this file's covered and uncovered lines + if covered_range is not None: + covered_lines[file_path].update(covered_range) + if uncovered_range is not None: + uncovered_lines[file_path].update(uncovered_range) + + +def update_set() -> None: + for file_name in covered_lines: + # difference_update + uncovered_lines[file_name].difference_update(covered_lines[file_name]) + + +def summarize_jsons( + test_list: TestList, + interested_folders: List[str], + coverage_only: List[str], + program_start_time: float, +) -> None: + start_time = time.time() + parse_jsons(test_list, interested_folders) + update_set() + line_oriented_report( + test_list, + tests_type, + interested_folders, + coverage_only, + covered_lines, + uncovered_lines, + ) + file_oriented_report( + test_list, + tests_type, + interested_folders, + coverage_only, + program_start_time, + covered_lines, + uncovered_lines, + ) + print_time("summary jsons take time: ", start_time) diff --git a/tools/code_coverage/package/tool/utils.py b/tools/code_coverage/package/tool/utils.py new file mode 100644 index 000000000000000..221bac94c884990 --- /dev/null +++ b/tools/code_coverage/package/tool/utils.py @@ -0,0 +1,23 @@ +import subprocess + +from ..util.setting import TestPlatform +from ..util.utils import print_error + + +def run_cpp_test(binary_file: str) -> None: + # cpp test binary + try: + subprocess.check_call(binary_file) + except subprocess.CalledProcessError: + print_error(f"Binary failed to run: {binary_file}") + + +def get_tool_path_by_platform(platform: TestPlatform): + if platform == TestPlatform.FBCODE: + from ..fbcode.utils import get_llvm_tool_path + + return get_llvm_tool_path() + else: + from ..oss.utils import get_llvm_tool_path + + return get_llvm_tool_path() diff --git a/tools/code_coverage/package/util/__init__.py b/tools/code_coverage/package/util/__init__.py new file mode 100644 index 000000000000000..e69de29bb2d1d64 diff --git a/tools/code_coverage/package/util/setting.py b/tools/code_coverage/package/util/setting.py new file mode 100644 index 000000000000000..59e4fad3709b945 --- /dev/null +++ b/tools/code_coverage/package/util/setting.py @@ -0,0 +1,62 @@ +import os +from enum import Enum +from typing import Dict, List, Set + + +# +HOME_DIR = os.environ["HOME"] +setting_file_path = os.path.realpath(__file__) +SCRIPT_FOLDER = os.path.join( + os.path.dirname(setting_file_path), os.path.pardir, os.path.pardir +) + + +# +PROFILE_DIR = os.path.join(SCRIPT_FOLDER, "profile") +JSON_FOLDER_BASE_DIR = os.path.join(PROFILE_DIR, "json") +MERGED_FOLDER_BASE_DIR = os.path.join(PROFILE_DIR, "merged") +SUMMARY_FOLDER_DIR = os.path.join(PROFILE_DIR, "summary") + +# +LOG_DIR = os.path.join(PROFILE_DIR, "log") + + +# test type, DO NOT change the name, it should be consistent with [buck query --output-attribute] result +class TestType(Enum): + CPP: str = "cxx_test" + PY: str = "python_test" + + +class Test: + name: str + target_pattern: str + test_set: str # like __aten__ + test_type: TestType + + def __init__( + self, name: str, target_pattern: str, test_set: str, test_type: TestType + ) -> None: + self.name = name + self.target_pattern = target_pattern + self.test_set = test_set + self.test_type = test_type + + +TestList = List[Test] +TestStatusType = Dict[str, Set[str]] + + +# option +class Option: + need_build: bool = False + need_run: bool = False + need_merge: bool = False + need_export: bool = False + need_summary: bool = False + need_pytest: bool = False + + +# test platform +class TestPlatform(Enum): + FBCODE: str = "fbcode" + OSS: str = "oss" diff --git a/tools/code_coverage/package/util/utils.py b/tools/code_coverage/package/util/utils.py new file mode 100644 index 000000000000000..cb5fc45cc3775db --- /dev/null +++ b/tools/code_coverage/package/util/utils.py @@ -0,0 +1,126 @@ +import os +import shutil +import sys +import time +from typing import Any + +from .setting import LOG_DIR, PROFILE_DIR, TestList, TestPlatform, TestType + + +def convert_time(seconds: float) -> str: + seconds = int(round(seconds)) + seconds = seconds % (24 * 3600) + hour = seconds // 3600 + seconds %= 3600 + minutes = seconds // 60 + seconds %= 60 + + return "%d:%02d:%02d" % (hour, minutes, seconds) + + +def print_time(message: str, start_time: float, summary_time: bool = False) -> None: + with open(os.path.join(LOG_DIR, "log.txt"), "a+") as log_file: + end_time = time.time() + print(message, convert_time(end_time - start_time), file=log_file) + if summary_time: + print("\n", file=log_file) + + +def print_log(*args: Any) -> None: + with open(os.path.join(LOG_DIR, "log.txt"), "a+") as log_file: + print(f"[LOG] {' '.join(args)}", file=log_file) + + +def print_error(*args: Any) -> None: + with open(os.path.join(LOG_DIR, "log.txt"), "a+") as log_file: + print(f"[ERROR] {' '.join(args)}", file=log_file) + + +def remove_file(path: str) -> None: + if os.path.exists(path): + os.remove(path) + + +def remove_folder(path: str) -> None: + shutil.rmtree(path) + + +def create_folder(*paths: Any) -> None: + for path in paths: + os.makedirs(path, exist_ok=True) + + +# clean up all the files generated by coverage tool +def clean_up() -> None: + # remove profile folder + remove_folder(PROFILE_DIR) + sys.exit("Clean Up Successfully!") + + +def convert_to_relative_path(whole_path: str, base_path: str) -> str: + # ("profile/raw", "profile") -> "raw" + if base_path not in whole_path: + raise RuntimeError(base_path + " is not in " + whole_path) + return whole_path[len(base_path) + 1 :] + + +def replace_extension(filename, ext): + return filename[: filename.rfind(".")] + ext + + +# a file is related if it's in one of the test_list folder +def related_to_test_list(file_name: str, test_list: TestList) -> bool: + for test in test_list: + if test.name in file_name: + return True + return False + + +def get_raw_profiles_folder() -> str: + return os.environ.get("RAW_PROFILES_FOLDER", os.path.join(PROFILE_DIR, "raw")) + + +# TODO auto detect +def get_cov_type() -> str: + return os.environ.get("COMPILER_TYPE", "CLANG") + + +def get_test_name_from_whole_path(path: str) -> str: + # code_coverage_tool/profile/merged/haha.merged -> haha + start = path.rfind("/") + end = path.rfind(".") + assert start >= 0 and end >= 0 + return path[start + 1 : end] + + +def check_compiler_type(cov_type: str) -> None: + if cov_type in ["CLANG", "GCC"]: + return + raise Exception( + f"Can't parse compiler type: {cov_type}.", + " Please set environment variable COMPILER_TYPE as CLANG or GCC", + ) + + +def check_platform_type(platform_type: TestPlatform) -> None: + if platform_type in [TestPlatform.OSS, TestPlatform.FBCODE]: + return + raise Exception( + f"Can't parse platform type: {platform_type}.", + " Please set environment variable COMPILER_TYPE as OSS or FBCODE", + ) + + +def check_test_type(test_type: str, target: str) -> None: + if test_type in [TestType.CPP.value, TestType.PY.value]: + return + raise Exception( + f"Can't parse test type: {test_type}.", + f" Please check the type of buck target: {target}", + ) + + +def raise_no_test_found_exception(cpp_binary_folder: str, python_binary_folder: str): + raise RuntimeError( + f"No cpp and python tests found in folder **{cpp_binary_folder} and **{python_binary_folder}**" + ) diff --git a/tools/code_coverage/package/util/utils_init.py b/tools/code_coverage/package/util/utils_init.py new file mode 100644 index 000000000000000..989da6435aef4fb --- /dev/null +++ b/tools/code_coverage/package/util/utils_init.py @@ -0,0 +1,101 @@ +import argparse +import os +from typing import Any + +from .setting import ( + JSON_FOLDER_BASE_DIR, + LOG_DIR, + MERGED_FOLDER_BASE_DIR, + PROFILE_DIR, + SUMMARY_FOLDER_DIR, + Option, +) +from .utils import create_folder, get_raw_profiles_folder, remove_file + + +def remove_files() -> None: + # remove log + remove_file(os.path.join(LOG_DIR, "log.txt")) + + +def create_folders() -> None: + create_folder( + PROFILE_DIR, + MERGED_FOLDER_BASE_DIR, + JSON_FOLDER_BASE_DIR, + get_raw_profiles_folder(), + SUMMARY_FOLDER_DIR, + LOG_DIR, + ) + + +def add_arguments_utils(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser.add_argument("--run", help="run the cpp test binaries", action="store_true") + parser.add_argument( + "--merge", + help="merge raw profiles (only apply to clang coverage)", + action="store_true", + ) + parser.add_argument( + "--export", help="generate json report for each file", action="store_true" + ) + parser.add_argument( + "--summary", + help="read json report and generate file/line-oriented summary", + action="store_true", + ) + parser.add_argument( + "--interested-folder", + help="Final report will be only about these folders and its sub-folders; for example: caff2/c10;", + nargs="+", + default=None, + ) + parser.add_argument( + "--clean", + help="delete all files generated by coverage tool", + action="store_true", + default=False, + ) + + return parser + + +def have_option(have_stage: bool, option: int) -> int: + if have_stage: + return option + else: + return 0 + + +def get_options(args: Any) -> Option: + option: Option = Option() + if args.__contains__("build"): + if args.build: + option.need_build = True + + if args.__contains__("run"): + if args.run: + option.need_run = True + + if args.__contains__("merge"): + if args.merge: + option.need_merge = True + + if args.__contains__("export"): + if args.export: + option.need_export = True + + if args.__contains__("summary"): + if args.summary: + option.need_summary = True + + # user does not have specified stage like run + if not any(vars(option).values()): + option.need_build = True + option.need_run = True + option.need_merge = True + option.need_export = True + option.need_summary = True + option.need_pytest = True + + return option diff --git a/tools/jit/gen_unboxing_wrappers.py b/tools/jit/gen_unboxing_wrappers.py index 0ba7d2bbbc37a27..5e275fac96950b1 100644 --- a/tools/jit/gen_unboxing_wrappers.py +++ b/tools/jit/gen_unboxing_wrappers.py @@ -166,11 +166,7 @@ def from_ivalue(arg, value): .layout(${layout}) .device(${device}) .pinned_memory(${pin_memory}); -#ifdef USE_STATIC_DISPATCH - auto result_ = at::${name}(${args_with_tensor_options}); -#else auto result_ = torch::${name}(${args_with_tensor_options}); -#endif """) CALL_METHOD_WITH_TENSOR_OPTIONS = CodeTemplate("""\ const auto options = TensorOptions() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9e1e0ba1b3e6a0c..92104122c76325e 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -143,15 +143,11 @@ class Graph: class FunctionSchema: ... -# Defined in torch/jit/api/module.h -class ExtraFilesMap: - ... - # Defined in torch/csrc/jit/python/script_init.cpp class ScriptFunction: def __call__(self, *args, **kwargs) -> Tensor: ... - def save(self, filename: str, _extra_files: ExtraFilesMap) -> None: ... - def save_to_buffer(self, _extra_files = ExtraFilesMap) -> bytes: ... + def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ... + def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ... def graph(self) -> Graph: ... def inlined_graph(self) -> Graph: ... def schema(self) -> FunctionSchema: ... diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index a16502748b968b6..292823414b8aeb8 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -709,6 +709,29 @@ def __getitem__(self, types): for i in range(2, 7): globals()["BroadcastingList{}".format(i)] = BroadcastingList1 + +def is_scripting(): + r""" + Function that returns True when in compilation and False otherwise. This + is useful especially with the @unused decorator to leave code in your + model that is not yet TorchScript compatible. + .. testcode:: + + import torch + + @torch.jit.unused + def unsupported_linear_op(x): + return x + + def linear(x): + if not torch.jit.is_scripting(): + return torch.linear(x) + else: + return unsupported_linear_op(x) + """ + return False + + # Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. def _qualified_name(obj): # This special case allows us to override the qualified name on a type. diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index f603b115b52c0e2..cdd1086bac3c068 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -544,6 +544,18 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.asinh` """) +add_docstr_all('arcsinh', r""" +arcsinh() -> Tensor + +See :func:`torch.arcsinh` +""") + +add_docstr_all('arcsinh_', r""" +arcsinh_() -> Tensor + +In-place version of :meth:`~Tensor.arcsinh` +""") + add_docstr_all('as_strided', r""" as_strided(size, stride, storage_offset=0) -> Tensor @@ -574,32 +586,42 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.arctan` """) -add_docstr_all('atan2', - r""" +add_docstr_all('atan2', r""" atan2(other) -> Tensor See :func:`torch.atan2` """) -add_docstr_all('atan2_', - r""" +add_docstr_all('atan2_', r""" atan2_(other) -> Tensor In-place version of :meth:`~Tensor.atan2` """) -add_docstr_all('atanh', - r""" +add_docstr_all('atanh', r""" atanh() -> Tensor See :func:`torch.atanh` """) -add_docstr_all('atanh_', - r""" +add_docstr_all('atanh_', r""" +atanh_(other) -> Tensor + In-place version of :meth:`~Tensor.atanh` """) +add_docstr_all('arctanh', r""" +arctanh() -> Tensor + +See :func:`torch.arctanh` +""") + +add_docstr_all('arctanh_', r""" +arctanh_(other) -> Tensor + +In-place version of :meth:`~Tensor.arctanh` +""") + add_docstr_all('baddbmm', r""" baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor @@ -3145,18 +3167,10 @@ def callable(a, b) -> number """) -add_docstr_all('sub', - r""" +add_docstr_all('sub', r""" sub(other, *, alpha=1) -> Tensor -Subtracts a scalar or tensor from :attr:`self` tensor. If both :attr:`alpha` -and :attr:`other` are specified, each element of :attr:`other` is scaled by -:attr:`alpha` before being used. - -When :attr:`other` is a tensor, the shape of :attr:`other` must be -:ref:`broadcastable ` with the shape of the underlying -tensor. - +See :func:`torch.sub`. """) add_docstr_all('sub_', @@ -3166,6 +3180,18 @@ def callable(a, b) -> number In-place version of :meth:`~Tensor.sub` """) +add_docstr_all('subtract', r""" +subtract(other, *, alpha=1) -> Tensor + +See :func:`torch.subtract`. +""") + +add_docstr_all('subtract_', r""" +subtract_(other, *, alpha=1) -> Tensor + +In-place version of :meth:`~Tensor.subtract`. +""") + add_docstr_all('sum', r""" sum(dim=None, keepdim=False, dtype=None) -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 27b47b4cee972df..c86ea77afa05e15 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -283,6 +283,9 @@ def merge_dicts(*dicts): .. math:: out = \beta\ \text{input} + \alpha\ (\sum_{i=0}^{b-1} \text{batch1}_i \mathbin{@} \text{batch2}_i) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. """ + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. @@ -411,6 +414,9 @@ def merge_dicts(*dicts): .. math:: \text{out} = \beta\ \text{input} + \alpha\ (\text{mat1}_i \mathbin{@} \text{mat2}_i) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. """ + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. @@ -453,6 +459,9 @@ def merge_dicts(*dicts): .. math:: \text{out} = \beta\ \text{input} + \alpha\ (\text{mat} \mathbin{@} \text{vec}) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. """ + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers @@ -489,6 +498,9 @@ def merge_dicts(*dicts): .. math:: \text{out} = \beta\ \text{input} + \alpha\ (\text{vec1} \otimes \text{vec2}) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. """ + r""" If :attr:`vec1` is a vector of size `n` and :attr:`vec2` is a vector of size `m`, then :attr:`input` must be @@ -699,6 +711,12 @@ def merge_dicts(*dicts): tensor([ 0.1599, -1.1534, -0.9435, -0.8990 ]) """.format(**common_args)) +add_docstr(torch.arcsinh, r""" +arcsinh(input, *, out=None) -> Tensor + +Alias for :func:`torch.asinh`. +""") + add_docstr(torch.atan, r""" atan(input, *, out=None) -> Tensor @@ -758,8 +776,7 @@ def merge_dicts(*dicts): tensor([ 0.9833, 0.0811, -1.9743, -1.4151]) """.format(**common_args)) -add_docstr(torch.atanh, - r""" +add_docstr(torch.atanh, r""" atanh(input, *, out=None) -> Tensor Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. @@ -787,6 +804,12 @@ def merge_dicts(*dicts): tensor([ -1.7253, 0.3060, -1.2899, -0.1893 ]) """.format(**common_args)) +add_docstr(torch.arctanh, r""" +arctanh(input, *, out=None) -> Tensor + +Alias for :func:`torch.atanh`. +""") + add_docstr(torch.baddbmm, r""" baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor @@ -807,6 +830,9 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \beta\ \text{input}_i + \alpha\ (\text{batch1}_i \mathbin{@} \text{batch2}_i) + +If :attr:`beta` is 0, then :attr:`input` will be ignored, and `nan` and `inf` in +it will not be propagated. """ + r""" For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and :attr:`alpha` must be real numbers, otherwise they should be integers. @@ -2645,7 +2671,7 @@ def merge_dicts(*dicts): add_docstr(torch.floor, r""" -floor(input, out=None) -> Tensor +floor(input, *, out=None) -> Tensor Returns a new tensor with the floor of the elements of :attr:`input`, the largest integer less than or equal to each element. @@ -2655,6 +2681,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -2668,7 +2696,7 @@ def merge_dicts(*dicts): add_docstr(torch.floor_divide, r""" -floor_divide(input, other, out=None) -> Tensor +floor_divide(input, other, *, out=None) -> Tensor Return the division of the inputs rounded down to the nearest integer. See :func:`torch.div` for type promotion and broadcasting rules. @@ -2696,7 +2724,7 @@ def merge_dicts(*dicts): add_docstr(torch.fmod, r""" -fmod(input, other, out=None) -> Tensor +fmod(input, other, *, out=None) -> Tensor Computes the element-wise remainder of division. @@ -2709,6 +2737,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the dividend other (Tensor or float): the divisor, which may be either a number or a tensor of the same shape as the dividend + +Keyword args: {out} Example:: @@ -2723,7 +2753,7 @@ def merge_dicts(*dicts): add_docstr(torch.frac, r""" -frac(input, out=None) -> Tensor +frac(input, *, out=None) -> Tensor Computes the fractional portion of each element in :attr:`input`. @@ -2786,9 +2816,10 @@ def merge_dicts(*dicts): [5, 6, 7, 8]]) """.format(**common_args)) +# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.gather, r""" -gather(input, dim, index, out=None, sparse_grad=False) -> Tensor +gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor Gathers values along an axis specified by `dim`. @@ -2808,8 +2839,8 @@ def merge_dicts(*dicts): input (Tensor): the source tensor dim (int): the axis along which to index index (LongTensor): the indices of elements to gather - out (Tensor, optional): the destination tensor sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor. + out (Tensor, optional): the destination tensor Example:: @@ -2822,7 +2853,7 @@ def merge_dicts(*dicts): add_docstr(torch.gcd, r""" -gcd(input, other, out=None) -> Tensor +gcd(input, other, *, out=None) -> Tensor Computes the element-wise greatest common divisor (GCD) of :attr:`input` and :attr:`other`. @@ -2876,7 +2907,7 @@ def merge_dicts(*dicts): add_docstr(torch.geqrf, r""" -geqrf(input, out=None) -> (Tensor, Tensor) +geqrf(input, *, out=None) -> (Tensor, Tensor) This is a low-level function for calling LAPACK directly. This function returns a namedtuple (a, tau) as defined in `LAPACK documentation for geqrf`_ . @@ -2893,6 +2924,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the input matrix + +Keyword args: out (tuple, optional): the output tuple of (Tensor, Tensor) .. _LAPACK documentation for geqrf: @@ -2908,7 +2941,7 @@ def merge_dicts(*dicts): add_docstr(torch.ger, r""" -ger(input, vec2, out=None) -> Tensor +ger(input, vec2, *, out=None) -> Tensor Outer product of :attr:`input` and :attr:`vec2`. If :attr:`input` is a vector of size :math:`n` and :attr:`vec2` is a vector of @@ -2919,6 +2952,8 @@ def merge_dicts(*dicts): Args: input (Tensor): 1-D input vector vec2 (Tensor): 1-D input vector + +Keyword args: out (Tensor, optional): optional output matrix Example:: @@ -2934,7 +2969,7 @@ def merge_dicts(*dicts): add_docstr(torch.solve, r""" -torch.solve(input, A, out=None) -> (Tensor, Tensor) +torch.solve(input, A, *, out=None) -> (Tensor, Tensor) This function returns the solution to the system of linear equations represented by :math:`AX = B` and the LU factorization of @@ -2958,6 +2993,8 @@ def merge_dicts(*dicts): is zero or more batch dimensions. A (Tensor): input square matrix of size :math:`(*, m, m)`, where :math:`*` is zero or more batch dimensions. + +Keyword args: out ((Tensor, Tensor), optional): optional output tuple. Example:: @@ -3046,7 +3083,7 @@ def merge_dicts(*dicts): add_docstr(torch.histc, r""" -histc(input, bins=100, min=0, max=0, out=None) -> Tensor +histc(input, bins=100, min=0, max=0, *, out=None) -> Tensor Computes the histogram of a tensor. @@ -3061,6 +3098,8 @@ def merge_dicts(*dicts): bins (int): number of histogram bins min (int): lower end of the range (inclusive) max (int): upper end of the range (inclusive) + +Keyword args: {out} Returns: @@ -3100,7 +3139,7 @@ def merge_dicts(*dicts): add_docstr(torch.index_select, r""" -index_select(input, dim, index, out=None) -> Tensor +index_select(input, dim, index, *, out=None) -> Tensor Returns a new tensor which indexes the :attr:`input` tensor along dimension :attr:`dim` using the entries in :attr:`index` which is a `LongTensor`. @@ -3118,6 +3157,8 @@ def merge_dicts(*dicts): {input} dim (int): the dimension in which we index index (LongTensor): the 1-D tensor containing the indices to index + +Keyword args: {out} Example:: @@ -3139,7 +3180,7 @@ def merge_dicts(*dicts): add_docstr(torch.inverse, r""" -inverse(input, out=None) -> Tensor +inverse(input, *, out=None) -> Tensor Takes the inverse of the square matrix :attr:`input`. :attr:`input` can be batches of 2D square tensors, in which case this function would return a tensor composed of @@ -3153,6 +3194,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more batch dimensions + +Keyword args: {out} Example:: @@ -3374,7 +3417,7 @@ def merge_dicts(*dicts): add_docstr(torch.kthvalue, r""" -kthvalue(input, k, dim=None, keepdim=False, out=None) -> (Tensor, LongTensor) +kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the :attr:`k` th smallest element of each row of the :attr:`input` tensor in the given dimension @@ -3393,6 +3436,8 @@ def merge_dicts(*dicts): k (int): k for the k-th smallest element dim (int, optional): the dimension to find the kth value along {keepdim} + +Keyword args: out (tuple, optional): the output tuple of (Tensor, LongTensor) can be optionally given to be used as output buffers @@ -3414,7 +3459,7 @@ def merge_dicts(*dicts): add_docstr(torch.lcm, r""" -lcm(input, other, out=None) -> Tensor +lcm(input, other, *, out=None) -> Tensor Computes the element-wise least common multiple (LCM) of :attr:`input` and :attr:`other`. @@ -3469,7 +3514,7 @@ def merge_dicts(*dicts): add_docstr(torch.lerp, r""" -lerp(input, end, weight, out=None) +lerp(input, end, weight, *, out=None) Does a linear interpolation of two tensors :attr:`start` (given by :attr:`input`) and :attr:`end` based on a scalar or tensor :attr:`weight` and returns the resulting :attr:`out` tensor. @@ -3485,6 +3530,8 @@ def merge_dicts(*dicts): input (Tensor): the tensor with the starting points end (Tensor): the tensor with the ending points weight (float or tensor): the weight for the interpolation formula + +Keyword args: {out} Example:: @@ -3503,7 +3550,7 @@ def merge_dicts(*dicts): add_docstr(torch.lgamma, r""" -lgamma(input, out=None) -> Tensor +lgamma(input, *, out=None) -> Tensor Computes the logarithm of the gamma function on :attr:`input`. @@ -3521,9 +3568,10 @@ def merge_dicts(*dicts): tensor([ 0.5724, 0.0000, -0.1208]) """.format(**common_args)) +# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.linspace, r""" -linspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +linspace(start, end, steps=100, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a one-dimensional tensor of :attr:`steps` equally spaced points between :attr:`start` and :attr:`end`. @@ -3556,7 +3604,7 @@ def merge_dicts(*dicts): add_docstr(torch.log, r""" -log(input, out=None) -> Tensor +log(input, *, out=None) -> Tensor Returns a new tensor with the natural logarithm of the elements of :attr:`input`. @@ -3564,8 +3612,11 @@ def merge_dicts(*dicts): .. math:: y_{i} = \log_{e} (x_{i}) """ + r""" + Args: {input} + +Keyword args: {out} Example:: @@ -3579,7 +3630,7 @@ def merge_dicts(*dicts): add_docstr(torch.log10, r""" -log10(input, out=None) -> Tensor +log10(input, *, out=None) -> Tensor Returns a new tensor with the logarithm to the base 10 of the elements of :attr:`input`. @@ -3587,8 +3638,11 @@ def merge_dicts(*dicts): .. math:: y_{i} = \log_{10} (x_{i}) """ + r""" + Args: {input} + +Keyword args: {out} Example:: @@ -3605,7 +3659,7 @@ def merge_dicts(*dicts): add_docstr(torch.log1p, r""" -log1p(input, out=None) -> Tensor +log1p(input, *, out=None) -> Tensor Returns a new tensor with the natural logarithm of (1 + :attr:`input`). @@ -3617,6 +3671,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword args: {out} Example:: @@ -3630,7 +3686,7 @@ def merge_dicts(*dicts): add_docstr(torch.log2, r""" -log2(input, out=None) -> Tensor +log2(input, *, out=None) -> Tensor Returns a new tensor with the logarithm to the base 2 of the elements of :attr:`input`. @@ -3638,8 +3694,11 @@ def merge_dicts(*dicts): .. math:: y_{i} = \log_{2} (x_{i}) """ + r""" + Args: {input} + +Keyword args: {out} Example:: @@ -3656,7 +3715,7 @@ def merge_dicts(*dicts): add_docstr(torch.logaddexp, r""" -logaddexp(input, other, out=None) -> Tensor +logaddexp(input, other, *, out=None) -> Tensor Logarithm of the sum of exponentiations of the inputs. @@ -3688,7 +3747,7 @@ def merge_dicts(*dicts): add_docstr(torch.logaddexp2, r""" -logaddexp2(input, other, out=None) -> Tensor +logaddexp2(input, other, *, out=None) -> Tensor Logarithm of the sum of exponentiations of the inputs in base-2. @@ -3705,7 +3764,7 @@ def merge_dicts(*dicts): add_docstr(torch.logical_and, r""" -logical_and(input, other, out=None) -> Tensor +logical_and(input, other, *, out=None) -> Tensor Computes the element-wise logical AND of the given input tensors. Zeros are treated as ``False`` and nonzeros are treated as ``True``. @@ -3713,6 +3772,8 @@ def merge_dicts(*dicts): Args: {input} other (Tensor): the tensor to compute AND with + +Keyword args: {out} Example:: @@ -3733,13 +3794,15 @@ def merge_dicts(*dicts): add_docstr(torch.logical_not, r""" -logical_not(input, out=None) -> Tensor +logical_not(input, *, out=None) -> Tensor Computes the element-wise logical NOT of the given input tensor. If not specified, the output tensor will have the bool dtype. If the input tensor is not a bool tensor, zeros are treated as ``False`` and non-zeros are treated as ``True``. Args: {input} + +Keyword args: {out} Example:: @@ -3756,7 +3819,7 @@ def merge_dicts(*dicts): add_docstr(torch.logical_or, r""" -logical_or(input, other, out=None) -> Tensor +logical_or(input, other, *, out=None) -> Tensor Computes the element-wise logical OR of the given input tensors. Zeros are treated as ``False`` and nonzeros are treated as ``True``. @@ -3764,6 +3827,8 @@ def merge_dicts(*dicts): Args: {input} other (Tensor): the tensor to compute OR with + +Keyword args: {out} Example:: @@ -3784,7 +3849,7 @@ def merge_dicts(*dicts): add_docstr(torch.logical_xor, r""" -logical_xor(input, other, out=None) -> Tensor +logical_xor(input, other, *, out=None) -> Tensor Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are treated as ``True``. @@ -3792,6 +3857,8 @@ def merge_dicts(*dicts): Args: {input} other (Tensor): the tensor to compute XOR with + +Keyword args: {out} Example:: @@ -3810,10 +3877,12 @@ def merge_dicts(*dicts): tensor([ True, True, False, False]) """.format(**common_args)) +# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.logspace, - r""" -logspace(start, end, steps=100, base=10.0, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor - + """ +logspace(start, end, steps=100, base=10.0, *, \ + out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +""" + r""" Returns a one-dimensional tensor of :attr:`steps` points logarithmically spaced with base :attr:`base` between :math:`{{\text{{base}}}}^{{\text{{start}}}}` and :math:`{{\text{{base}}}}^{{\text{{end}}}}`. @@ -3846,7 +3915,7 @@ def merge_dicts(*dicts): add_docstr(torch.logsumexp, r""" -logsumexp(input, dim, keepdim=False, out=None) +logsumexp(input, dim, keepdim=False, *, out=None) Returns the log of summed exponentials of each row of the :attr:`input` tensor in the given dimension :attr:`dim`. The computation is numerically @@ -3863,6 +3932,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: {out} @@ -3874,7 +3945,7 @@ def merge_dicts(*dicts): add_docstr(torch.lstsq, r""" -lstsq(input, A, out=None) -> Tensor +lstsq(input, A, *, out=None) -> Tensor Computes the solution to the least squares and least norm problems for a full rank matrix :math:`A` of size :math:`(m \times n)` and a matrix :math:`B` of @@ -3907,6 +3978,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the matrix :math:`B` A (Tensor): the :math:`m` by :math:`n` matrix :math:`A` + +Keyword args: out (tuple, optional): the optional destination tensor Returns: @@ -3969,7 +4042,7 @@ def merge_dicts(*dicts): add_docstr(torch.lu_solve, r""" -lu_solve(input, LU_data, LU_pivots, out=None) -> Tensor +lu_solve(input, LU_data, LU_pivots, *, out=None) -> Tensor Returns the LU solve of the linear system :math:`Ax = b` using the partially pivoted LU factorization of A from :meth:`torch.lu`. @@ -3983,6 +4056,8 @@ def merge_dicts(*dicts): where :math:`*` is zero or more batch dimensions. The batch dimensions of :attr:`LU_pivots` must be equal to the batch dimensions of :attr:`LU_data`. + +Keyword args: {out} Example:: @@ -3998,7 +4073,7 @@ def merge_dicts(*dicts): add_docstr(torch.masked_select, r""" -masked_select(input, mask, out=None) -> Tensor +masked_select(input, mask, *, out=None) -> Tensor Returns a new 1-D tensor which indexes the :attr:`input` tensor according to the boolean mask :attr:`mask` which is a `BoolTensor`. @@ -4012,6 +4087,8 @@ def merge_dicts(*dicts): Args: {input} mask (BoolTensor): the tensor containing the binary mask to index with + +Keyword args: {out} Example:: @@ -4159,7 +4236,7 @@ def merge_dicts(*dicts): >>> torch.max(a) tensor(0.7445) -.. function:: max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor) +.. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum value of each row of the :attr:`input` tensor in the given dimension @@ -4182,6 +4259,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} Default: ``False``. + +Keyword args: out (tuple, optional): the result tuple of two output tensors (max, max_indices) Example:: @@ -4234,6 +4313,8 @@ def merge_dicts(*dicts): This is the second value returned by :meth:`torch.max`. See its documentation for the exact semantics of this method. +.. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + Args: {input} @@ -4289,7 +4370,7 @@ def merge_dicts(*dicts): >>> torch.mean(a) tensor(0.3367) -.. function:: mean(input, dim, keepdim=False, out=None) -> Tensor +.. function:: mean(input, dim, keepdim=False, *, out=None) -> Tensor Returns the mean value of each row of the :attr:`input` tensor in the given dimension :attr:`dim`. If :attr:`dim` is a list of dimensions, @@ -4301,6 +4382,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: {out} Example:: @@ -4340,7 +4423,7 @@ def merge_dicts(*dicts): >>> torch.median(a) tensor(0.2202) -.. function:: median(input, dim=-1, keepdim=False, out=None) -> (Tensor, LongTensor) +.. function:: median(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the median value of each row of the :attr:`input` tensor in the given dimension @@ -4365,6 +4448,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: out (tuple, optional): the result tuple of two output tensors (max, max_indices) Example:: @@ -4463,7 +4548,7 @@ def merge_dicts(*dicts): >>> torch.min(a) tensor(0.6750) -.. function:: min(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor) +.. function:: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the minimum value of each row of the :attr:`input` tensor in the given dimension @@ -4486,6 +4571,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: out (tuple, optional): the tuple of two output tensors (min, min_indices) Example:: @@ -4537,6 +4624,8 @@ def merge_dicts(*dicts): This is the second value returned by :meth:`torch.min`. See its documentation for the exact semantics of this method. +.. note:: If there are multiple minimal values then the indices of the first minimal value are returned. + Args: {input} @@ -4551,7 +4640,7 @@ def merge_dicts(*dicts): >>> torch.argmin(a) tensor(13) -.. function:: argmin(input, dim, keepdim=False, out=None) -> LongTensor +.. function:: argmin(input, dim, keepdim=False) -> LongTensor Returns the indices of the minimum values of a tensor across a dimension. @@ -4577,7 +4666,7 @@ def merge_dicts(*dicts): add_docstr(torch.mm, r""" -mm(input, mat2, out=None) -> Tensor +mm(input, mat2, *, out=None) -> Tensor Performs a matrix multiplication of the matrices :attr:`input` and :attr:`mat2`. @@ -4590,6 +4679,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the first matrix to be multiplied mat2 (Tensor): the second matrix to be multiplied + +Keyword args: {out} Example:: @@ -4603,7 +4694,7 @@ def merge_dicts(*dicts): add_docstr(torch.matmul, r""" -matmul(input, other, out=None) -> Tensor +matmul(input, other, *, out=None) -> Tensor Matrix product of two tensors. @@ -4633,6 +4724,8 @@ def merge_dicts(*dicts): Arguments: input (Tensor): the first tensor to be multiplied other (Tensor): the second tensor to be multiplied + +Keyword args: {out} Example:: @@ -4667,7 +4760,7 @@ def merge_dicts(*dicts): add_docstr(torch.mode, r""" -mode(input, dim=-1, keepdim=False, out=None) -> (Tensor, LongTensor) +mode(input, dim=-1, keepdim=False, *, out=None) -> (Tensor, LongTensor) Returns a namedtuple ``(values, indices)`` where ``values`` is the mode value of each row of the :attr:`input` tensor in the given dimension @@ -4687,6 +4780,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: out (tuple, optional): the result tuple of two output tensors (values, indices) Example:: @@ -4701,7 +4796,7 @@ def merge_dicts(*dicts): add_docstr(torch.mul, r""" -mul(input, other, out=None) +mul(input, other, *, out=None) Multiplies each element of the input :attr:`input` with the scalar :attr:`other` and returns a new resulting tensor. @@ -4714,7 +4809,9 @@ def merge_dicts(*dicts): Args: {input} - value (Number): the number to be multiplied to each element of :attr:`input` + other (Number): the number to be multiplied to each element of :attr:`input` + +Keyword args: {out} Example:: @@ -4725,7 +4822,7 @@ def merge_dicts(*dicts): >>> torch.mul(a, 100) tensor([ 20.1494, -42.5491, 260.8663]) -.. function:: mul(input, other, out=None) +.. function:: mul(input, other, *, out=None) Each element of the tensor :attr:`input` is multiplied by the corresponding element of the Tensor :attr:`other`. The resulting tensor is returned. @@ -4736,9 +4833,12 @@ def merge_dicts(*dicts): .. math:: \text{out}_i = \text{input}_i \times \text{other}_i """ + r""" + Args: input (Tensor): the first multiplicand tensor other (Tensor): the second multiplicand tensor + +Keyword args: {out} Example:: @@ -4794,6 +4894,8 @@ def merge_dicts(*dicts): input (Tensor): the input tensor containing probabilities num_samples (int): number of samples to draw replacement (bool, optional): whether to draw with replacement or not + +Keyword args: {generator} {out} @@ -4811,7 +4913,7 @@ def merge_dicts(*dicts): add_docstr(torch.mv, r""" -mv(input, vec, out=None) -> Tensor +mv(input, vec, *, out=None) -> Tensor Performs a matrix-vector product of the matrix :attr:`input` and the vector :attr:`vec`. @@ -4824,6 +4926,8 @@ def merge_dicts(*dicts): Args: input (Tensor): matrix to be multiplied vec (Tensor): vector to be multiplied + +Keyword args: {out} Example:: @@ -4961,7 +5065,7 @@ def merge_dicts(*dicts): add_docstr(torch.neg, r""" -neg(input, out=None) -> Tensor +neg(input, *, out=None) -> Tensor Returns a new tensor with the negative of the elements of :attr:`input`. @@ -4970,6 +5074,8 @@ def merge_dicts(*dicts): """ + r""" Args: {input} + +Keyword args: {out} Example:: @@ -5000,6 +5106,7 @@ def merge_dicts(*dicts): Args: input (Tensor): the first input tensor other (Tensor): the second input tensor + Keyword args: {out} @@ -5052,6 +5159,8 @@ def merge_dicts(*dicts): Args: {input} + +Keyword args: out (LongTensor, optional): the output tensor containing indices Returns: @@ -5108,6 +5217,8 @@ def merge_dicts(*dicts): Args: mean (Tensor): the tensor of per-element means std (Tensor): the tensor of per-element standard deviations + +Keyword args: {generator} {out} @@ -5117,7 +5228,7 @@ def merge_dicts(*dicts): tensor([ 1.0425, 3.5672, 2.7969, 4.2925, 4.7229, 6.2134, 8.0505, 8.1408, 9.0563, 10.0566]) -.. function:: normal(mean=0.0, std, out=None) -> Tensor +.. function:: normal(mean=0.0, std, *, out=None) -> Tensor Similar to the function above, but the means are shared among all drawn elements. @@ -5125,6 +5236,8 @@ def merge_dicts(*dicts): Args: mean (float, optional): the mean for all distributions std (Tensor): the tensor of per-element standard deviations + +Keyword args: {out} Example:: @@ -5132,7 +5245,7 @@ def merge_dicts(*dicts): >>> torch.normal(mean=0.5, std=torch.arange(1., 6.)) tensor([-1.2793, -1.0732, -2.0687, 5.1177, -1.2303]) -.. function:: normal(mean, std=1.0, out=None) -> Tensor +.. function:: normal(mean, std=1.0, *, out=None) -> Tensor Similar to the function above, but the standard-deviations are shared among all drawn elements. @@ -5140,6 +5253,8 @@ def merge_dicts(*dicts): Args: mean (Tensor): the tensor of per-element means std (float, optional): the standard deviation for all distributions + +Keyword args: out (Tensor, optional): the output tensor Example:: @@ -5156,6 +5271,8 @@ def merge_dicts(*dicts): mean (float): the mean for all distributions std (float): the standard deviation for all distributions size (int...): a sequence of integers defining the shape of the output tensor. + +Keyword args: {out} Example:: @@ -5184,9 +5301,10 @@ def merge_dicts(*dicts): """.format(**common_args)) +# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.ones, r""" -ones(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor +ones(*size, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor Returns a tensor filled with the scalar value `1`, with the shape defined by the variable argument :attr:`size`. @@ -5211,9 +5329,10 @@ def merge_dicts(*dicts): """.format(**factory_common_args)) +# TODO: see https://github.com/pytorch/pytorch/issues/43667 add_docstr(torch.ones_like, r""" -ones_like(input, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor +ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor Returns a tensor filled with the scalar value `1`, with the same size as :attr:`input`. ``torch.ones_like(input)`` is equivalent to @@ -5292,6 +5411,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the input tensor containing the rates of the Poisson distribution + +Keyword args: {generator} Example:: @@ -5306,7 +5427,7 @@ def merge_dicts(*dicts): add_docstr(torch.polygamma, r""" -polygamma(n, input, out=None) -> Tensor +polygamma(n, input, *, out=None) -> Tensor Computes the :math:`n^{th}` derivative of the digamma function on :attr:`input`. :math:`n \geq 0` is called the order of the polygamma function. @@ -5320,6 +5441,8 @@ def merge_dicts(*dicts): Args: n (int): the order of the polygamma function {input} + +Keyword args: {out} Example:: @@ -5336,7 +5459,7 @@ def merge_dicts(*dicts): add_docstr(torch.pow, r""" -pow(input, exponent, out=None) -> Tensor +pow(input, exponent, *, out=None) -> Tensor Takes the power of each element in :attr:`input` with :attr:`exponent` and returns a tensor with the result. @@ -5360,6 +5483,8 @@ def merge_dicts(*dicts): Args: {input} exponent (float or tensor): the exponent value + +Keyword args: {out} Example:: @@ -5379,7 +5504,7 @@ def merge_dicts(*dicts): >>> torch.pow(a, exp) tensor([ 1., 4., 27., 256.]) -.. function:: pow(self, exponent, out=None) -> Tensor +.. function:: pow(self, exponent, *, out=None) -> Tensor :attr:`self` is a scalar ``float`` value, and :attr:`exponent` is a tensor. The returned tensor :attr:`out` is of the same shape as :attr:`exponent` @@ -5392,6 +5517,8 @@ def merge_dicts(*dicts): Args: self (float): the scalar base value for the power operation exponent (Tensor): the exponent tensor + +Keyword args: {out} Example:: @@ -5404,12 +5531,14 @@ def merge_dicts(*dicts): add_docstr(torch.prod, r""" -prod(input, dtype=None) -> Tensor +prod(input, *, dtype=None) -> Tensor Returns the product of all elements in the :attr:`input` tensor. Args: {input} + +Keyword args: {dtype} Example:: @@ -5420,7 +5549,7 @@ def merge_dicts(*dicts): >>> torch.prod(a) tensor(0.6902) -.. function:: prod(input, dim, keepdim=False, dtype=None) -> Tensor +.. function:: prod(input, dim, keepdim=False, *, dtype=None) -> Tensor Returns the product of each row of the :attr:`input` tensor in the given dimension :attr:`dim`. @@ -5431,6 +5560,8 @@ def merge_dicts(*dicts): {input} {dim} {keepdim} + +Keyword args: {dtype} Example:: @@ -6576,6 +6707,40 @@ def merge_dicts(*dicts): (tensor([0.9110, 0.8197, 1.2552, 1.0608]), tensor([-0.6871, 0.6229, 0.2169, -0.9058])) """.format(**multi_dim_common)) +add_docstr(torch.sub, r""" +sub(input, other, *, alpha=1, out=None) -> Tensor + +Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. + +.. math:: + \text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i +""" + r""" + +Supports :ref:`broadcasting to a common shape `, +:ref:`type promotion `, and integer, float, and complex inputs. + +Args: + {input} + other (Tensor or Scalar): the tensor or scalar to subtract from :attr:`input` + +Keyword args: + alpha (Scalar): the scalar multiplier for :attr:`other` + {out} + +Example:: + + >>> a = torch.tensor((1, 2)) + >>> b = torch.tensor((0, 1)) + >>> torch.sub(a, b, alpha=2) + tensor([1, 0]) +""".format(**common_args)) + +add_docstr(torch.subtract, r""" +subtract(input, other, *, alpha=1, out=None) -> Tensor + +Alias for :func:`torch.sub`. +""") + add_docstr(torch.sum, r""" sum(input, dtype=None) -> Tensor @@ -8763,7 +8928,7 @@ def merge_dicts(*dicts): r""" quantize_per_tensor(input, scale, zero_point, dtype) -> Tensor -Converts a float tensor to quantized tensor with given scale and zero point. +Converts a float tensor to a quantized tensor with given scale and zero point. Arguments: input (Tensor): float tensor to quantize @@ -8788,7 +8953,7 @@ def merge_dicts(*dicts): r""" quantize_per_channel(input, scales, zero_points, axis, dtype) -> Tensor -Converts a float tensor to per-channel quantized tensor with given scales and zero points. +Converts a float tensor to a per-channel quantized tensor with given scales and zero points. Arguments: input (Tensor): float tensor to quantize diff --git a/torch/_utils.py b/torch/_utils.py index c63e30955f015c8..11f378a4d7f9b95 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -421,6 +421,10 @@ def reraise(self): # makes stack traces unreadable. It will not be changed in Python # (https://bugs.python.org/issue2651), so we work around it. msg = KeyErrorMessage(msg) + elif getattr(self.exc_type, "message", None): + # Some exceptions have first argument as non-str but explicitly + # have message field + raise self.exc_type(message=msg) raise self.exc_type(msg) diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index baa4b6363deb48d..e1b303129080d8f 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -1,6 +1,7 @@ import torch import functools import inspect +from typing import Any class _DecoratorContextManager: """Allow a context manager to be used as a decorator""" @@ -60,11 +61,16 @@ class no_grad(_DecoratorContextManager): >>> z.requires_grad False """ + def __init__(self): + if not torch._jit_internal.is_scripting(): + super().__init__() + self.prev = False + def __enter__(self): self.prev = torch.is_grad_enabled() - torch._C.set_grad_enabled(False) + torch.set_grad_enabled(False) - def __exit__(self, *args): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): torch.set_grad_enabled(self.prev) diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index b05911c08d33498..04b547145ec04c3 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -61,6 +61,25 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3): x_tensors = iter_tensors(target, True) j_tensors = iter_tensors(jacobian) + def compute_gradient(x, idx, is_mkldnn=False): + + def fn_out(): + if not is_mkldnn: + # x is a view into input and so this works + return fn(input).clone() + else: + # convert the dense tensor back to have mkldnn layout + return fn([x.to_mkldnn()]) + + orig = x[idx].item() + x[idx] = orig - eps + outa = fn_out() + x[idx] = orig + eps + outb = fn_out() + x[idx] = orig + r = (outb - outa) / (2 * eps) + return r.detach().reshape(-1) + # TODO: compare structure for x_tensor, d_tensor in zip(x_tensors, j_tensors): is_complex = x_tensor.dtype.is_complex @@ -90,14 +109,7 @@ def get_stride(size): for x_idx in product(*[range(m) for m in x_values.size()[1:]]): indices = x_indices[i].tolist() + list(x_idx) d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size))) - orig = x_value[x_idx].item() - x_value[x_idx] = orig - eps - outa = fn(input).clone() - x_value[x_idx] = orig + eps - outb = fn(input).clone() - x_value[x_idx] = orig - r = (outb - outa) / (2 * eps) - d_tensor[d_idx] = r.detach().reshape(-1) + d_tensor[d_idx] = compute_gradient(x_value, x_idx) elif x_tensor.layout == torch._mkldnn: # Use .data here to get around the version check x_tensor = x_tensor.data @@ -108,30 +120,12 @@ def get_stride(size): # this is really inefficient, but without indexing implemented, there's # not really a better way than converting back and forth x_tensor_dense = x_tensor.to_dense() - orig = x_tensor_dense[x_idx].item() - - x_tensor_dense[x_idx] = orig - eps - x_tensor_mkl = x_tensor_dense.to_mkldnn() - outa = fn([x_tensor_mkl]) - - x_tensor_dense[x_idx] = orig + eps - x_tensor_mkl = x_tensor_dense.to_mkldnn() - outb = fn([x_tensor_mkl]) - - r = (outb - outa) / (2 * eps) - d_tensor[d_idx] = r.detach().reshape(-1) + d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True) else: # Use .data here to get around the version check x_tensor = x_tensor.data for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): - orig = x_tensor[x_idx].item() - x_tensor[x_idx] = orig - eps - outa = fn(input).clone() - x_tensor[x_idx] = orig + eps - outb = fn(input).clone() - x_tensor[x_idx] = orig - r = (outb - outa) / (2 * eps) - d_tensor[d_idx] = r.detach().reshape(-1) + d_tensor[d_idx] = compute_gradient(x_tensor, x_idx) return jacobian diff --git a/torch/csrc/api/include/torch/all.h b/torch/csrc/api/include/torch/all.h index 5717bccf6017e54..5bcc8eec93abd34 100644 --- a/torch/csrc/api/include/torch/all.h +++ b/torch/csrc/api/include/torch/all.h @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h index d1c4c60df6f207f..5ce90dcc972e543 100644 --- a/torch/csrc/api/include/torch/linalg.h +++ b/torch/csrc/api/include/torch/linalg.h @@ -12,6 +12,22 @@ inline Tensor det(const Tensor& self) { return torch::linalg_det(self); } +inline Tensor norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& norm_out(Tensor& result, const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +} + } // namespace detail #endif /* DOXYGEN_SHOULD_SKIP_THIS */ @@ -21,4 +37,20 @@ inline Tensor linalg_det(const Tensor& self) { return detail::det(self); } +inline Tensor linalg_norm(const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor linalg_norm(const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, optional opt_ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); +} + +inline Tensor& linalg_norm_out(Tensor& result, const Tensor& self, std::string ord, optional opt_dim, bool keepdim, optional opt_dtype) { + return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); +} + }} // torch::linalg diff --git a/torch/csrc/api/include/torch/nn/modules/transformercoder.h b/torch/csrc/api/include/torch/nn/modules/transformercoder.h index 44ae9928177fb8a..e022824392c664e 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/modules/transformercoder.h @@ -67,5 +67,73 @@ class TORCH_API TransformerEncoderImpl : public Cloneable { + public: + TransformerDecoderImpl(TransformerDecoderLayer decoder_layer, int64_t num_layers) + : TransformerDecoderImpl(TransformerDecoderOptions(decoder_layer, num_layers)) {} + explicit TransformerDecoderImpl(TransformerDecoderOptions options_); + + void reset() override; + + void reset_parameters(); + + /// Pass the inputs (and mask) through the decoder layer in turn. + ///Args: + /// tgt: the sequence to the decoder layer (required). + /// memory: the sequence from the last layer of the encoder (required). + /// tgt_mask: the mask for the tgt sequence (optional). + /// memory_mask: the mask for the memory sequence (optional). + /// tgt_key_padding_mask: the mask for the tgt keys per batch (optional). + /// memory_key_padding_mask: the mask for the memory keys per batch (optional). + Tensor forward(const Tensor& tgt, + const Tensor& memory, + const Tensor& tgt_mask = {}, + const Tensor& memory_mask = {}, + const Tensor& tgt_key_padding_mask = {}, + const Tensor& memory_key_padding_mask = {}); + + /// The options used to configure this module. + TransformerDecoderOptions options; + + ///Cloned layers of decoder layers + ModuleList layers{nullptr}; + + ///optional layer normalization module + AnyModule norm; + + protected: + FORWARD_HAS_DEFAULT_ARGS( + {2, AnyValue(Tensor())}, + {3, AnyValue(Tensor())}, + {4, AnyValue(Tensor())}, + {5, AnyValue(Tensor())}) + + }; + +/// A `ModuleHolder` subclass for `TransformerDecoderImpl`. +/// See the documentation for `TransformerDecoderImpl` class to learn what methods it +/// provides, and examples of how to use `TransformerDecoder` with +/// `torch::nn::TransformerDecoderOptions`. +/// See the documentation for `ModuleHolder` to learn about PyTorch's +/// module storage semantics. +TORCH_MODULE(TransformerDecoder); + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/include/torch/nn/options/transformercoder.h b/torch/csrc/api/include/torch/nn/options/transformercoder.h index 9a0a2bc5d01c49c..6c5975ef5c12acc 100644 --- a/torch/csrc/api/include/torch/nn/options/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/options/transformercoder.h @@ -34,5 +34,34 @@ namespace nn { TORCH_ARG(AnyModule, norm); }; +/// Options for the `TransformerDecoder` module. +/// +/// Example: +/// ``` +/// TransformerDecoderLayer decoder_layer(TransformerDecoderLayerOptions(512, 8).dropout(0.1)); +/// auto options = TransformerDecoderOptions(decoder_layer, 6)norm(LayerNorm(LayerNormOptions({2}))); +/// TransformerDecoder transformer_decoder(options); +/// ``` +struct TORCH_API TransformerDecoderOptions { + // This constructor will keep the a ref of passed in decoder_layer, + // so it keeps all the data in decoder_layer. + TransformerDecoderOptions(TransformerDecoderLayer decoder_layer, + int64_t num_layers); + // This constructor will create a new TransformerDecoderLayer obj, + // based on passed in decoder_layer_options. + TransformerDecoderOptions( + const TransformerDecoderLayerOptions& decoder_layer_options, + int64_t num_layers); + + /// decoder layer to be cloned + TORCH_ARG(TransformerDecoderLayer, decoder_layer) = nullptr; + + /// number of decoder layers + TORCH_ARG(int64_t, num_layers); + + /// normalization module + TORCH_ARG(AnyModule, norm); +}; + } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/modules/transformer.cpp b/torch/csrc/api/src/nn/modules/transformer.cpp index 76183c7b71aba9f..20c43af21028b1f 100644 --- a/torch/csrc/api/src/nn/modules/transformer.cpp +++ b/torch/csrc/api/src/nn/modules/transformer.cpp @@ -155,7 +155,9 @@ void TransformerDecoderLayerImpl::reset_parameters() { } ///Pass the inputs (and mask) through the decoder layer. -Tensor TransformerDecoderLayerImpl::forward(Tensor tgt, const Tensor& memory, +Tensor TransformerDecoderLayerImpl::forward( + Tensor tgt, + const Tensor& memory, const Tensor& tgt_mask, const Tensor& memory_mask, const Tensor& tgt_key_padding_mask, @@ -262,6 +264,83 @@ Tensor TransformerEncoderImpl::forward( return output; } +// ========================TransformerDecoderImpl========================= +TransformerDecoderImpl::TransformerDecoderImpl( + TransformerDecoderOptions options_ ) : options(std::move(options_)){ + reset(); +} + +void TransformerDecoderImpl::reset() { + + layers = this->register_module("layers", ModuleList()); + for (int64_t i = 0; i < options.num_layers(); ++i) { + layers->push_back(options.decoder_layer()->clone()); + } + + if (!options.norm().is_empty()) { + norm = options.norm().clone(); + this->register_module("norm", norm.ptr()); + } +} + +void TransformerDecoderImpl::reset_parameters() { + + TORCH_CHECK(layers->size() == options.num_layers(), + "TransformerDecoder should have", options.num_layers(), + " decoder layers, but got ", layers->size()); + + size_t num_layers = layers->size(); + for (size_t i = 0; i < num_layers; ++i) { + layers->at(i).reset_parameters(); + } + // a. No way to know whether module in AnyModule has api to reset_parameters, so replace instead + // b. Allow user to add/delete normalization module when reset parameters + if (!norm.is_empty()) { + this->unregister_module("norm"); + norm = AnyModule(); + } + if (!options.norm().is_empty()) { + norm = options.norm().clone(); + this->register_module("norm", norm.ptr()); + } + +} + +Tensor TransformerDecoderImpl::forward( + const Tensor& tgt, + const Tensor& memory, + const Tensor& tgt_mask, + const Tensor& memory_mask, + const Tensor& tgt_key_padding_mask, + const Tensor& memory_key_padding_mask){ + + size_t num_layers = layers->size(); + Tensor output; + if (num_layers > 0) { + output = layers->at(0).forward( + tgt, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask); + } + for (size_t i = 1; i < num_layers; ++i) { + output = layers->at(i).forward( + output, + memory, + tgt_mask, + memory_mask, + tgt_key_padding_mask, + memory_key_padding_mask); + } + + if (!norm.is_empty()) { + output = norm.forward(num_layers == 0 ? tgt : output); + } + + return output; +} } // namespace nn } // namespace torch diff --git a/torch/csrc/api/src/nn/options/transformer.cpp b/torch/csrc/api/src/nn/options/transformer.cpp index 766c34ce6583f85..fd83aa81a31e240 100644 --- a/torch/csrc/api/src/nn/options/transformer.cpp +++ b/torch/csrc/api/src/nn/options/transformer.cpp @@ -8,8 +8,8 @@ TransformerEncoderLayerOptions::TransformerEncoderLayerOptions( int64_t d_model, int64_t nhead) : d_model_(d_model), nhead_(nhead) {} -TransformerDecoderLayerOptions::TransformerDecoderLayerOptions(int64_t d_model, int64_t nhead) -: d_model_(d_model), nhead_(nhead){} +TransformerDecoderLayerOptions::TransformerDecoderLayerOptions( + int64_t d_model, int64_t nhead) : d_model_(d_model), nhead_(nhead){} TransformerEncoderOptions::TransformerEncoderOptions( @@ -21,5 +21,16 @@ TransformerEncoderOptions::TransformerEncoderOptions( const TransformerEncoderLayerOptions& encoder_layer_options, int64_t num_layers) : encoder_layer_(encoder_layer_options), num_layers_(num_layers) {} + +TransformerDecoderOptions::TransformerDecoderOptions( + TransformerDecoderLayer decoder_layer, int64_t num_layers) : + decoder_layer_(std::move(decoder_layer)), num_layers_(num_layers) {} + + +TransformerDecoderOptions::TransformerDecoderOptions( + const TransformerDecoderLayerOptions& decoder_layer_options, + int64_t num_layers) + : decoder_layer_(decoder_layer_options), num_layers_(num_layers){} + } // namespace nn } // namespace torch diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index 12773aafcf96f7a..692972533adcceb 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -55,6 +55,12 @@ inline void check_inplace(const Tensor& tensor) { } } +inline void check_inplace(const TensorList tensors) { + for (const auto& tensor : tensors) { + check_inplace(tensor); + } +} + inline void throw_error_out_requires_grad(const char* name) { AT_ERROR( name, "(): functions with out=... arguments don't support automatic differentiation, " diff --git a/torch/csrc/autograd/edge.h b/torch/csrc/autograd/edge.h index 3d055acc8b0f73a..76e6baf8aa38741 100644 --- a/torch/csrc/autograd/edge.h +++ b/torch/csrc/autograd/edge.h @@ -4,7 +4,7 @@ #include #include -#include +#include namespace torch { namespace autograd { @@ -50,7 +50,7 @@ struct hash { using argument_type = torch::autograd::Edge; using return_type = size_t; return_type operator()(const argument_type& edge) const noexcept { - return torch::get_hash(edge.function, edge.input_nr); + return c10::get_hash(edge.function, edge.input_nr); } }; } // namespace std diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index dbfad76c05123c6..38f477e9ef988a2 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -454,8 +454,8 @@ void GraphTask::mark_as_completed_and_run_post_processing() { if (future_completed_.exchange(true)) { // Future is already marked complete, or being marked as such. // In case the marking complete is only in progress, we add a - // waitNoThrow() to guarantee the future is marked complete on exit. - future_result_->waitNoThrow(); + // wait() to guarantee the future is marked complete on exit. + future_result_->wait(); return; } @@ -871,7 +871,13 @@ auto Engine::execute(const edge_list& roots, graph_task->init_to_execute(*graph_root, outputs); } - return execute_with_graph_task(graph_task, graph_root)->wait(); + execute_with_graph_task(graph_task, graph_root); + // Avoid a refcount bump for the Future, since we check for refcount in + // DistEngine (see TORCH_INTERNAL_ASSERT(futureGrads.use_count() == 1) + // in dist_engine.cpp). + auto& fut = graph_task->future_result_; + fut->wait(); + return fut->value().toTensorVector(); } void Engine::initialize_device_threads_pool() { @@ -882,7 +888,7 @@ void Engine::initialize_device_threads_pool() { std::call_once(start_device_threads_flag_, &Engine::start_device_threads, this); } -std::shared_ptr Engine::execute_with_graph_task( +std::shared_ptr Engine::execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root) { initialize_device_threads_pool(); diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index c4bd5976f1b9cae..f72bdec36a9df58 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -4,13 +4,13 @@ // to "root" variables (variables created by the user with requires_grad=True). #include +#include #include #include #include #include #include #include -#include #include #include @@ -28,8 +28,6 @@ struct ReadyQueue; namespace torch { namespace autograd { -using FutureVariableList = torch::utils::Future; - static constexpr int NO_DEVICE = -2; static constexpr int CPU_DEVICE = -1; @@ -153,7 +151,7 @@ struct GraphTask: std::enable_shared_from_this { // Future representing the completion of the graph task. Notified when all // tasks are done. - std::shared_ptr future_result_; + std::shared_ptr future_result_; // Final callbacks installed during execution of this GraphTask std::vector> final_callbacks_; @@ -173,7 +171,7 @@ struct GraphTask: std::enable_shared_from_this { reentrant_depth_(reentrant_depth), exit_on_error_(exit_on_error), cpu_ready_queue_(std::move(cpu_ready_queue)), - future_result_(std::make_shared()) {} + future_result_(std::make_shared(c10::ListType::create(c10::TensorType::get()))) {} private: // run GraphTask post processing void exec_post_processing(); @@ -281,7 +279,7 @@ struct TORCH_API Engine { // // NB: This API should only be used by internal autograd specific // machinery and shouldn't be exposed to users in anyway. - virtual std::shared_ptr execute_with_graph_task( + virtual std::shared_ptr execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root); diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index fd8de5c05e76e62..f4c88225efc804b 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -99,7 +99,7 @@ variable_list PythonEngine::execute( } } -std::shared_ptr PythonEngine::execute_with_graph_task( +std::shared_ptr PythonEngine::execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root) { try { diff --git a/torch/csrc/autograd/python_engine.h b/torch/csrc/autograd/python_engine.h index 2585ae3ad5df499..7d722d43d504ce4 100644 --- a/torch/csrc/autograd/python_engine.h +++ b/torch/csrc/autograd/python_engine.h @@ -25,7 +25,7 @@ struct PythonEngine : public Engine { bool create_graph, const edge_list& outputs = {}) override; - std::shared_ptr execute_with_graph_task( + std::shared_ptr execute_with_graph_task( const std::shared_ptr& graph_task, std::shared_ptr graph_root) override; diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index b18cb3b1acca5b8..6cef307c7ccebe9 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -1,11 +1,11 @@ #include #include #include -#include #include #include #include +#include #include @@ -158,7 +158,7 @@ struct NcclCommList { using device_list = std::vector; // accesses to this object have to be guarded by THC's CudaFreeMutex -static std::unordered_map> +static std::unordered_map> _communicators; ArrayRef get_communicators(TensorList inputs) { diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index 691c546f4b2ad21..ccf373224c72277 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -13,7 +13,6 @@ namespace autograd { using torch::autograd::AccumulateGrad; using torch::autograd::edge_list; using torch::autograd::Engine; -using torch::autograd::FutureVariableList; using torch::autograd::GraphRoot; using torch::autograd::GraphTask; using torch::autograd::GraphTaskGuard; @@ -390,9 +389,8 @@ std::shared_ptr DistEngine::runEngineAndAccumulateGradients( // future that waits for all gradient accumulation to finish. auto accumulateGradFuture = std::make_shared(); - futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture]( - const FutureVariableList& futureGrads) { - if (futureGrads.hasError()) { + futureGrads->addCallback([autogradContext, outputEdges, accumulateGradFuture, &futureGrads]() { + if (futureGrads->hasError()) { // Don't accumulate gradients if we receive an error. // We must add the node information here since DistEngine::execute // waits on accumulateGradFuture and will throw an exception once we @@ -401,13 +399,13 @@ std::shared_ptr DistEngine::runEngineAndAccumulateGradients( "Error on Node ", DistAutogradContainer::getInstance().getWorkerId(), ": ", - futureGrads.error()->what()); + futureGrads->error()->what()); accumulateGradFuture->setError(errorMsg); return; } try { - const variable_list& grads = futureGrads.constValue(); + const variable_list& grads = futureGrads->constValue().toTensorVector(); TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size()); accumulateGradFuture->markCompleted(rpc::Message()); } catch (std::exception& e) { diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index cc35d08c5c0da41..6d327b239c75c85 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -749,10 +749,10 @@ They are used in specifying strategies for reduction collectives, e.g., execution and it does not wait for the entire operation to complete on GPU. Note that ``FutureNCCL`` does not support ``NCCL_BLOCKING_WAIT`` flag or NCCL's ``barrier()``. In addition, if a callback function was added by ``fut.then()``, it will wait until - ``WorkNCCL``'s NCCL streams synchronize with a new stream from device's stream pool and - invoke the callback inline after running the callback on the new stream. ``fut.then()`` - will return another ``FutureNCCL`` that holds the return value of the callback and the - stream that runs the callback. + ``WorkNCCL``'s NCCL streams synchronize with ``ProcessGroupNCCL``'s dedicated callback + stream and invoke the callback inline after running the callback on the callback stream. + ``fut.then()`` will return another ``FutureNCCL`` that holds the return value of the + callback and a ``CUDAEvent`` that recorded the callback stream. Note that ``fut.done()`` returns if the enire operation is completed on the GPU. )"); diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 46b54e6e5462093..ccfa0f37a32fcd0 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -12,7 +13,6 @@ #include #include #include -#include #include namespace c10d { @@ -1184,7 +1184,7 @@ struct BucketKey { // See torch/csrc/utils/hash.h for dispatch code. static size_t hash(const BucketKey& key) { - return torch::get_hash(key.type, key.device); + return c10::get_hash(key.type, key.device); } }; @@ -1220,7 +1220,7 @@ std::vector> compute_bucket_assignment_by_size( std::unordered_map< BucketKey, std::vector::const_iterator, - torch::hash> + c10::hash> bucket_size_limit_iterators; // Local accumulator type for a single bucket. @@ -1230,7 +1230,7 @@ std::vector> compute_bucket_assignment_by_size( }; // Keep vector of indices and size accumulator by tensor type and device. - std::unordered_map> + std::unordered_map> buckets; for (size_t i = 0; i < tensors.size(); i++) { diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index 503c8c2f393c340..79a8574eb392fc1 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -95,7 +95,7 @@ PyObject* rpc_init(PyObject* /* unused */) { // unqualified "hash" function call. However the // argument-dependent lookup for the function "hash" doesn't get // triggered in this context because it conflicts with the struct - // torch::hash, so we need to use the qualified name + // c10::hash, so we need to use the qualified name // py::detail::hash, which unfortunately is in a detail namespace. .def(py::detail::hash(py::self)) // NOLINT .def("__repr__", [](const WorkerInfo& workerInfo) { diff --git a/torch/csrc/jit/codegen/fuser/arg_spec.h b/torch/csrc/jit/codegen/fuser/arg_spec.h index 3a73f01961267a3..d6a5ca9816420b9 100644 --- a/torch/csrc/jit/codegen/fuser/arg_spec.h +++ b/torch/csrc/jit/codegen/fuser/arg_spec.h @@ -1,9 +1,9 @@ #pragma once #include #include // fmap +#include #include #include -#include #include #include @@ -20,7 +20,7 @@ namespace fuser { struct TORCH_API ArgSpec { ArgSpec(at::TensorList inputs, const int _device) : descs_{c10::fmap(inputs)}, - hash_code_{torch::get_hash(_device, inputs.size(), descs_)}, + hash_code_{c10::get_hash(_device, inputs.size(), descs_)}, device_{_device} {} // (Common) hash function diff --git a/torch/csrc/jit/codegen/fuser/compiler.cpp b/torch/csrc/jit/codegen/fuser/compiler.cpp index 4bedcae289820ee..e49a6a692345737 100644 --- a/torch/csrc/jit/codegen/fuser/compiler.cpp +++ b/torch/csrc/jit/codegen/fuser/compiler.cpp @@ -145,7 +145,7 @@ static std::vector getInputDependencies(const Value* output) { } static void setInputBroadcastGroups(KernelSpec& spec) { - std::unordered_set, torch::hash>> + std::unordered_set, c10::hash>> broadcast_groups; for (const Value* output : (spec.graph())->outputs()) { if (output->node()->kind() == prim::FusedConcat) { diff --git a/torch/csrc/jit/codegen/fuser/kernel_spec.h b/torch/csrc/jit/codegen/fuser/kernel_spec.h index a1d3f341bd6e940..50b3f5783a7ac57 100644 --- a/torch/csrc/jit/codegen/fuser/kernel_spec.h +++ b/torch/csrc/jit/codegen/fuser/kernel_spec.h @@ -139,7 +139,7 @@ struct TORCH_API KernelSpec { bool has_random_; mutable std::mutex mutex_; mutable std:: - unordered_map, torch::hash> + unordered_map, c10::hash> kernels_; }; diff --git a/torch/csrc/jit/codegen/fuser/tensor_desc.h b/torch/csrc/jit/codegen/fuser/tensor_desc.h index ee1011431c9f1a4..200191451f8373e 100644 --- a/torch/csrc/jit/codegen/fuser/tensor_desc.h +++ b/torch/csrc/jit/codegen/fuser/tensor_desc.h @@ -3,8 +3,8 @@ #include #include #include +#include #include -#include #include #include @@ -79,7 +79,7 @@ struct TORCH_API TensorDesc { } static size_t hash(const TensorDesc& spec) { - return torch::get_hash( + return c10::get_hash( spec.scalar_type, spec.nDim_, std::hash>{}(spec.contiguity)); diff --git a/torch/csrc/jit/frontend/ir_emitter.cpp b/torch/csrc/jit/frontend/ir_emitter.cpp index 59e6da9fb0f14a8..06d9d7077b5197d 100644 --- a/torch/csrc/jit/frontend/ir_emitter.cpp +++ b/torch/csrc/jit/frontend/ir_emitter.cpp @@ -57,7 +57,7 @@ struct Refinement { struct RefinementSet { // When a comparison like x is None is made, we associate type refinements // with its true value and its false value. If a boolean that has refinements - // associated with it is used in a conditional of an if statememt, the true + // associated with it is used in a conditional of an if statement, the true // and false refinements are inserted into the corresponding blocks using Refinements = std::vector; @@ -205,7 +205,7 @@ static std::shared_ptr makeMagic( // The Environment keeps track of two tables, one for values which are not first // class and a type table for values which are. When a first class value // is set in the environment, we emit a prim::Store which sets the -// name of the variable to approriate type, and when a first-class value is +// name of the variable to appropriate type, and when a first-class value is // referenced we emit a prim::Load that generates a value of the appropriate // type. // @@ -700,7 +700,7 @@ struct to_ir { def.range(), Expr(Compound::create(TK_NONE, def.range(), {})))); } else { // if we haven't seen any return statements, but the graph block exits - // (the funciton always throws) then we accept the declared return type if + // (the function always throws) then we accept the declared return type if // it exists or set it to none if (def_stack_.back().merged_return_type_ == nullptr) { def_stack_.back().merged_return_type_ = @@ -1414,7 +1414,7 @@ struct to_ir { // the scope of the if statement (all variables are scoped to the function). // Script is a subset of python: we consider variables to be in scope // as long as there is a definition of the variable along all paths - // through the if statemnent + // through the if statement // ---- // if ...: // a = diff --git a/torch/csrc/jit/ir/alias_analysis.cpp b/torch/csrc/jit/ir/alias_analysis.cpp index ea20a8ace4dad2c..de14f9d7c05eb9b 100644 --- a/torch/csrc/jit/ir/alias_analysis.cpp +++ b/torch/csrc/jit/ir/alias_analysis.cpp @@ -476,6 +476,7 @@ void AliasDb::analyzeImpl(Node* node) { case prim::CudaFusionGroup: case prim::FunctionalGraph: case prim::DifferentiableGraph: + case prim::FallbackGraph: return analyzeSubgraph(node); case prim::fork: return analyzeFork(node); @@ -522,6 +523,7 @@ void AliasDb::analyzeImpl(Node* node) { return analyzeBroadcastingChunk(node); case prim::SetAttr: return analyzeSetAttr(node); + case prim::profile_optional: case prim::profile: makePointerTo(node->output(), node->inputs().at(0)); return; diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index b8f097b137374ec..be42aaf7c1c9051 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -1973,10 +1973,21 @@ void ProfileOp::cloneFrom(Node* other_) { auto other = other_->cast(); this->callback_ = other->getCallback(); } + Node* ProfileOp::allocNewInstance(Graph* g) { return new ProfileOp(g, {nullptr}); } +void ProfileOptionalOp::cloneFrom(Node* other_) { + Node::cloneFrom(other_); + auto other = other_->cast(); + this->callback_ = other->getCallback(); +} + +Node* ProfileOptionalOp::allocNewInstance(Graph* g) { + return new ProfileOptionalOp(g, {nullptr}); +} + TypePtr NamedValue::type() const { if (value_) { return value_->type(); @@ -1986,6 +1997,7 @@ TypePtr NamedValue::type() const { } constexpr Symbol ProfileOp::Kind; +constexpr Symbol ProfileOptionalOp::Kind; OperatorSet::OperatorSet(std::initializer_list sig_literals) { for (const char* sig : sig_literals) { diff --git a/torch/csrc/jit/ir/ir.h b/torch/csrc/jit/ir/ir.h index cb1f0662fa5e440..665bd9797b26bc0 100644 --- a/torch/csrc/jit/ir/ir.h +++ b/torch/csrc/jit/ir/ir.h @@ -435,6 +435,12 @@ struct TORCH_API Node { bool isNondeterministic() const; bool hasSideEffects() const; + // instructions lowered by the interpreter and not run in the optimized graph + bool notExecutedOp() const { + return kind_ == prim::Constant || kind_ == prim::profile || + kind_ == prim::profile_optional; + } + // Graphs // Note [Topological invariant] @@ -1337,6 +1343,28 @@ struct ProfileOp : public Node { std::function&)> callback_; }; +struct TORCH_API ProfileOptionalOp : public Node { + static constexpr Symbol Kind = ::c10::prim::profile_optional; + ProfileOptionalOp( + Graph* graph, + std::function&)> callback) + : Node(graph, ::c10::prim::profile_optional), callback_(callback) {} + + void cloneFrom(Node* other_) override; + Node* allocNewInstance(Graph* g) override; + + const std::function&)>& getCallback() const { + return callback_; + } + + void setCallback(std::function&)> callback) { + callback_ = callback; + } + + private: + std::function&)> callback_; +}; + // execute a Python function, used for Ops we can't optimize but that we want to // optimize around // diff --git a/torch/csrc/jit/ir/irparser.cpp b/torch/csrc/jit/ir/irparser.cpp index ffa65c06ecba8ee..4c4ce31d3b97e38 100644 --- a/torch/csrc/jit/ir/irparser.cpp +++ b/torch/csrc/jit/ir/irparser.cpp @@ -191,9 +191,9 @@ void IRParser::parseAttr(Node* n) { if (L.cur().kind == '[') { // list AttributeKind k = AttributeKind::ts; - std::vector is; - std::vector ss; - std::vector fs; + c10::List is; + c10::List ss; + c10::List fs; int elem_num = 0; parseList('[', ',', ']', [&] { ParsedLiteral r = parseScalarLiteral(n); @@ -219,16 +219,16 @@ void IRParser::parseAttr(Node* n) { }); switch (k) { case AttributeKind::ts: - n->ts_(Symbol::attr(attrname), {}); + n->ival_(Symbol::attr(attrname), IValue()); break; case AttributeKind::ss: - n->ss_(Symbol::attr(attrname), ss); + n->ival_(Symbol::attr(attrname), IValue(ss)); break; case AttributeKind::fs: - n->fs_(Symbol::attr(attrname), fs); + n->ival_(Symbol::attr(attrname), IValue(fs)); break; case AttributeKind::is: - n->is_(Symbol::attr(attrname), is); + n->ival_(Symbol::attr(attrname), IValue(is)); break; default: throw ErrorReport(L.cur().range) << "Unexpected attr type"; diff --git a/torch/csrc/jit/ir/node_hashing.cpp b/torch/csrc/jit/ir/node_hashing.cpp index 6524d325fa454a2..52cace15075ffe5 100644 --- a/torch/csrc/jit/ir/node_hashing.cpp +++ b/torch/csrc/jit/ir/node_hashing.cpp @@ -6,9 +6,9 @@ #include #include #include +#include #include #include -#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/ir/type_hashing.cpp b/torch/csrc/jit/ir/type_hashing.cpp index b544fa0fdcc77f7..a03f6508216fdea 100644 --- a/torch/csrc/jit/ir/type_hashing.cpp +++ b/torch/csrc/jit/ir/type_hashing.cpp @@ -2,8 +2,8 @@ #include #include #include +#include #include -#include namespace torch { namespace jit { diff --git a/torch/csrc/jit/passes/constant_propagation.cpp b/torch/csrc/jit/passes/constant_propagation.cpp index d34d2634b034f90..c3285f2e24265d4 100644 --- a/torch/csrc/jit/passes/constant_propagation.cpp +++ b/torch/csrc/jit/passes/constant_propagation.cpp @@ -95,6 +95,7 @@ std::unordered_set skip_list = { prim::Uninitialized, prim::Guard, prim::profile, + prim::profile_optional, prim::unchecked_unwrap_optional, // TODO remove // TODO (zach): we should consider skipping tensor factories in the cases // where the constant tensor would be large but cheap to create. diff --git a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp index 7486aea6ab6805f..11bee519292c5f3 100644 --- a/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp +++ b/torch/csrc/jit/passes/create_autodiff_subgraphs.cpp @@ -228,8 +228,7 @@ class SubgraphSlicer { size_t i = 0; for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end(); ++it) { - // constants are not interpreted as instructions, ignore them - i += it->kind() != prim::Constant; + i += !it->notExecutedOp(); if (i >= minSubgraphSize_) { return false; } diff --git a/torch/csrc/jit/passes/graph_rewrite_helper.cpp b/torch/csrc/jit/passes/graph_rewrite_helper.cpp index cba5d2dcd5ad636..90caf4690b8d301 100644 --- a/torch/csrc/jit/passes/graph_rewrite_helper.cpp +++ b/torch/csrc/jit/passes/graph_rewrite_helper.cpp @@ -78,6 +78,13 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { %r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups) return (%r) )"; + std::string conv2d_transpose = R"( + graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], + %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, + %deterministic:bool, %cudnn_enabled:bool): + %r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + std::string conv1d = R"( graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[], %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool, @@ -124,6 +131,22 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { (calc_value_map["output_padding"].toIntList()[0] == 0) && (calc_value_map["output_padding"].toIntList()[1] == 0); }; + auto filter_conv2d_transpose = + [](const Match& match, + const std::unordered_map& vmap) { + auto calc_value_map = getConvParams(match, vmap); + if (calc_value_map["output_padding"].toIntList().size() != 2 || + calc_value_map["stride"].toIntList().size() != 2 || + calc_value_map["padding"].toIntList().size() != 2 || + calc_value_map["dilation"].toIntList().size() != 2) { + return false; + } + + return calc_value_map["transposed"].toBool() && + !calc_value_map["benchmark"].toBool() && + !calc_value_map["deterministic"].toBool() && + calc_value_map["cudnn_enabled"].toBool(); + }; auto filter_conv3d = [](const Match& match, const std::unordered_map& vmap) { auto calc_value_map = getConvParams(match, vmap); @@ -148,6 +171,10 @@ void replaceConvolutionWithAtenConv(std::shared_ptr& graph) { SubgraphRewriter rewriter_conv2d; rewriter_conv2d.RegisterRewritePattern(convolution, conv2d); rewriter_conv2d.runOnGraph(graph, filter_conv2d); + SubgraphRewriter rewriter_conv2d_transpose; + rewriter_conv2d_transpose.RegisterRewritePattern( + convolution, conv2d_transpose); + rewriter_conv2d_transpose.runOnGraph(graph, filter_conv2d_transpose); SubgraphRewriter rewriter_conv3d; rewriter_conv3d.RegisterRewritePattern(convolution, conv3d); rewriter_conv3d.runOnGraph(graph, filter_conv3d); diff --git a/torch/csrc/jit/passes/loop_unrolling.cpp b/torch/csrc/jit/passes/loop_unrolling.cpp index a017c28794935ff..d7b3ccf54d92331 100644 --- a/torch/csrc/jit/passes/loop_unrolling.cpp +++ b/torch/csrc/jit/passes/loop_unrolling.cpp @@ -35,10 +35,13 @@ bool isForLoop(Node* node) { int64_t limitedBlockSize(Block* body, int64_t limit) { auto it = body->nodes().begin(); auto end = body->nodes().end(); - for (int64_t i = 0; i < limit; ++i, ++it) { + for (int64_t i = 0; i < limit; ++it) { for (Block* subblock : it->blocks()) { i += limitedBlockSize(subblock, limit - i); } + if (!it->notExecutedOp()) { + ++i; + } if (it == end) { return i; } diff --git a/torch/csrc/jit/passes/lower_graph.cpp b/torch/csrc/jit/passes/lower_graph.cpp index 50ee4856bd0a874..581f4d14d8b42ed 100644 --- a/torch/csrc/jit/passes/lower_graph.cpp +++ b/torch/csrc/jit/passes/lower_graph.cpp @@ -34,7 +34,7 @@ std::pair, std::vector> lower_graph( std::size_t operator()(const Slot& slot) const { auto obj_hash = std::hash{}(slot.obj.get()); auto offset_hash = std::hash{}(slot.offset); - return torch::hash_combine(obj_hash, offset_hash); + return c10::hash_combine(obj_hash, offset_hash); } }; std::unordered_map slot_to_offset; diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index def4f8fd2b7ee69..15ffeb0ce7095b1 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -8,24 +8,18 @@ namespace { // map from op alias -> normalized op static const std::unordered_map alias_map = { - {aten::absolute, aten::abs}, - {aten::absolute_, aten::abs_}, - {aten::clip, aten::clamp}, - {aten::clip_, aten::clamp_}, - {aten::linalg_det, aten::det}, - {aten::outer, aten::ger}, - {aten::arccosh, aten::acosh}, - {aten::arccosh_, aten::acosh_}, - {aten::arccos, aten::acos}, - {aten::arccos_, aten::acos_}, - {aten::arcsin, aten::asin}, - {aten::arcsin_, aten::asin_}, - {aten::arctan, aten::atan}, - {aten::arctan_, aten::atan_}, - {aten::fix, aten::trunc}, - {aten::fix_, aten::trunc_}, - {aten::negative, aten::neg}, - {aten::negative_, aten::neg_}, + {aten::absolute, aten::abs}, {aten::absolute_, aten::abs_}, + {aten::clip, aten::clamp}, {aten::clip_, aten::clamp_}, + {aten::linalg_det, aten::det}, {aten::outer, aten::ger}, + {aten::arccos, aten::acos}, {aten::arccos_, aten::acos_}, + {aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_}, + {aten::arctan, aten::atan}, {aten::arctan_, aten::atan_}, + {aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_}, + {aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_}, + {aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_}, + {aten::fix, aten::trunc}, {aten::fix_, aten::trunc_}, + {aten::negative, aten::neg}, {aten::negative_, aten::neg_}, + {aten::subtract, aten::sub}, {aten::subtract_, aten::sub_}, }; void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) { diff --git a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp index 7a87f37dbc0800b..5549ea71c24f374 100644 --- a/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp +++ b/torch/csrc/jit/passes/onnx/preprocess_for_onnx.cpp @@ -67,10 +67,9 @@ void FuseWithListUnpack(Node* n) { // is aware of the number of outputs. // 2. Add the exact number of outputs to n, copy metadata and replace uses of // listUnpack outputs. - WithInsertPoint guard(n); - auto v_num_outputs = n->owningGraph()->insertConstant(at::full( - {1}, static_cast(listUnpack_node->outputs().size()), at::kLong)); - n->addInput(v_num_outputs); + n->i_( + Symbol::fromQualString("attr::_outputs"), + static_cast(listUnpack_node->outputs().size())); for (auto i = 0; i < listUnpack_node->outputs().size(); ++i) { auto new_output = n->addOutput(); @@ -96,6 +95,7 @@ static void FuseWithListUnpack(Block* b) { case aten::unsafe_split_with_sizes: case aten::unbind: case aten::unsafe_chunk: + case aten::where: FuseWithListUnpack(*it); break; default: diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 18364bac9aa63dc..4c0e0669052c002 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -134,6 +134,8 @@ Node* createInt(int64_t i, std::shared_ptr& graph) { return const_node; } +enum class QuantizedParamsType { CONV, LINEAR }; + // This is called before the onnx pass. Using pattern matching we // find the relevant nodes and extract the packed_params. The packed_params are // passed to the appropriate unpack function using c10::Dispatcher. We insert @@ -143,7 +145,8 @@ void unpackQuantizedWeightsHelper( std::shared_ptr& graph, std::map& paramsDict, const std::string& pattern, - const std::string& unpack_fn) { + const std::string& unpack_fn, + QuantizedParamsType params_type) { Graph pattern_graph; std::unordered_map vmap; parseIR(pattern, &pattern_graph, vmap); @@ -162,41 +165,95 @@ void unpackQuantizedWeightsHelper( } at::Tensor unpacked_weight; c10::optional bias; - const int64_t stride_idx = 2; - const int64_t padding_idx = 3; - const int64_t dilation_idx = 4; - const int64_t groups_idx = 5; + constexpr int64_t stride_idx = 2; + constexpr int64_t padding_idx = 3; + constexpr int64_t dilation_idx = 4; + constexpr int64_t groups_idx = 5; c10::optional> stride, padding, dilation; c10::optional groups; + torch::List stride_int, padding_int, dilation_int; + int64_t groups_int; + if (itr->second.isTuple()) { // Pre-unpacked weights. Comes from Conv/Linear weights which are // stored as bound C++ classes. auto ser_tup = itr->second.toTuple(); - unpacked_weight = ser_tup->elements()[0].toTensor(); - bias = ser_tup->elements()[1].toOptional(); - // conv only parameters - if (ser_tup->elements().size() > 2) { - auto stride_ivalue = ser_tup->elements()[stride_idx].toListRef(); - auto padding_ivalue = ser_tup->elements()[padding_idx].toListRef(); - auto dilation_ivalue = ser_tup->elements()[dilation_idx].toListRef(); - auto groups_ivalue = ser_tup->elements()[groups_idx]; - torch::List stride_int, padding_int, dilation_int; - int64_t groups_int; - for (const auto& s : stride_ivalue) { - stride_int.emplace_back(s.toTensor()[0].item()); + + if (params_type == QuantizedParamsType::CONV && + ser_tup->elements()[0].isString()) { + auto elements = ser_tup->elements(); + auto version = elements[0].toStringRef(); + TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version"); + std::vector non_optional = elements[1].toTensorVector(); + + at::Tensor conv_params_packed = non_optional[0]; + unpacked_weight = non_optional[1]; + + const int64_t kSpatialDim = conv_params_packed[0].item(); + // skip kSpatialDim + int64_t idx = 1; + for (int i = 0; i < kSpatialDim; ++i) { + stride_int.emplace_back(conv_params_packed[idx].item()); + idx++; + } + for (int i = 0; i < kSpatialDim; ++i) { + padding_int.emplace_back(conv_params_packed[idx].item()); + idx++; } - for (const auto& p : padding_ivalue) { - padding_int.emplace_back(p.toTensor()[0].item()); + for (int i = 0; i < kSpatialDim; ++i) { + dilation_int.emplace_back(conv_params_packed[idx].item()); + idx++; } - for (const auto& d : dilation_ivalue) { - dilation_int.emplace_back(d.toTensor()[0].item()); + // output_padding is not implemented yet, so we skip the entries + for (int i = 0; i < kSpatialDim; ++i) { + // do nothing + idx++; } - groups_int = groups_ivalue.toTensor()[0].item(); + groups_int = conv_params_packed[idx].item(); + idx++; + // skip transpose + idx++; + TORCH_INTERNAL_ASSERT( + idx == conv_params_packed.numel(), + "Unexpected length of conv_params_packed, expected ", + idx, + " got ", + conv_params_packed.numel()); + + torch::List optional = elements[2].toList(); + bias = optional.get(0).toOptional(); + stride = stride_int; padding = padding_int; dilation = dilation_int; groups = groups_int; + + } else { // Legacy + unpacked_weight = ser_tup->elements()[0].toTensor(); + bias = ser_tup->elements()[1].toOptional(); + // conv only parameters + if (ser_tup->elements().size() > 2) { + auto stride_ivalue = ser_tup->elements()[stride_idx].toListRef(); + auto padding_ivalue = ser_tup->elements()[padding_idx].toListRef(); + auto dilation_ivalue = ser_tup->elements()[dilation_idx].toListRef(); + auto groups_ivalue = ser_tup->elements()[groups_idx]; + + for (const auto& s : stride_ivalue) { + stride_int.emplace_back(s.toTensor()[0].item()); + } + for (const auto& p : padding_ivalue) { + padding_int.emplace_back(p.toTensor()[0].item()); + } + for (const auto& d : dilation_ivalue) { + dilation_int.emplace_back(d.toTensor()[0].item()); + } + groups_int = groups_ivalue.toTensor()[0].item(); + stride = stride_int; + padding = padding_int; + dilation = dilation_int; + groups = groups_int; + } } } else { TORCH_INTERNAL_ASSERT(itr->second.isTensor()); @@ -280,7 +337,7 @@ void unpackQuantizedWeightsHelper( c2_bias->insertBefore(qlinear_node); qlinear_node->insertInput(2, c2_bias->output()); - // add conv arguemnts: stride, padding, dilation, groups + // add conv arguments: stride, padding, dilation, groups if (stride.has_value() && padding.has_value() && dilation.has_value() && groups.has_value()) { std::vector>> conv_ints_args; @@ -327,15 +384,35 @@ void UnpackQuantizedWeights( %r = quantized::conv3d_relu(%input, %packed_params, %scale, %zero_point) return (%r) )"; unpackQuantizedWeightsHelper( - graph, paramsDict, qlinear, "quantized::linear_unpack"); + graph, + paramsDict, + qlinear, + "quantized::linear_unpack", + QuantizedParamsType::LINEAR); unpackQuantizedWeightsHelper( - graph, paramsDict, qconv2d, "quantized::conv2d_unpack"); + graph, + paramsDict, + qconv2d, + "quantized::conv2d_unpack", + QuantizedParamsType::CONV); unpackQuantizedWeightsHelper( - graph, paramsDict, qconv2d_relu, "quantized::conv2d_unpack"); + graph, + paramsDict, + qconv2d_relu, + "quantized::conv2d_unpack", + QuantizedParamsType::CONV); unpackQuantizedWeightsHelper( - graph, paramsDict, qconv3d, "quantized::conv3d_unpack"); + graph, + paramsDict, + qconv3d, + "quantized::conv3d_unpack", + QuantizedParamsType::CONV); unpackQuantizedWeightsHelper( - graph, paramsDict, qconv3d_relu, "quantized::conv3d_unpack"); + graph, + paramsDict, + qconv3d_relu, + "quantized::conv3d_unpack", + QuantizedParamsType::CONV); } // Caffe2 expects quantized ops to be in NHWC format while pytorch inputs are in diff --git a/torch/csrc/jit/passes/specialize_autogradzero.cpp b/torch/csrc/jit/passes/specialize_autogradzero.cpp index bc5fb2f7ff2f141..208bc04a62f7d40 100644 --- a/torch/csrc/jit/passes/specialize_autogradzero.cpp +++ b/torch/csrc/jit/passes/specialize_autogradzero.cpp @@ -1,154 +1,350 @@ #include +#include +#include +#include +#include #include namespace torch { namespace jit { -// propagate autograd zero information through a gradient graph and -// remove grad_of blocks if present. -// Note: this is a very limited pass. It only propagates autograd zeros for -// operations generated by the symbolic autodiff code and cleans up -// AutogradAdds when possible. Outputs of other nodes are conservatively -// marked Unknown and not optimized. -void specializeAutogradZero(Graph& g) { +struct AutogradZeroSpecializer { enum class State { Nonzero, Zero, Unknown }; - std::unordered_map state; - - for (Value* input : g.inputs()) { - const auto& tp = input->type(); - if (auto tt = tp->cast()) { - if (tt->undefined()) { - if (*tt->undefined()) { - state[input] = State::Zero; - } else { - state[input] = State::Nonzero; - } - } else { - state[input] = State::Unknown; + + AutogradZeroSpecializer(std::shared_ptr graph) + : graph_(std::move(graph)) {} + + void run() { + if (!isBackwardGraph()) { + return; + } + if (getProfilingMode()) { + if (auto versioning_if = guardSpecializations()) { + specializeAutogradOps(versioning_if->blocks()[0]); + GRAPH_DUMP("After versioning graph", graph_); } - } else if ( - tp->isSubtypeOf(TensorType::get()) || - tp->isSubtypeOf(ListType::ofTensors())) { - state[input] = State::Nonzero; } else { - state[input] = State::Unknown; + setStatesOnGraphInputs(); + specializeAutogradOps(graph_->block()); } + GRAPH_DUMP("After specializeAutogradOps graph", graph_); } - for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) { - auto n = *it; - - switch (n->kind()) { - case prim::AutogradAdd: { - auto a = n->input(0); - auto b = n->input(1); - // if one is Autograd zero, we can just drop the add - if (state[a] == State::Zero) { - // Zero + b == b - n->output()->replaceAllUsesWith(b); - it.destroyCurrent(); - } else if (state[b] == State::Zero) { - // a + Zero == a - n->output()->replaceAllUsesWith(a); - it.destroyCurrent(); - } else if (state[a] == State::Nonzero && state[b] == State::Nonzero) { - // when both are Nonzero, we can use a normal, optimizable add - // instruction - WithInsertPoint guard(n); - auto* g = n->owningGraph(); - auto* cOne = g->insertConstant(1); - auto* add_node = g->insertNode(g->create(aten::add, 1)); - add_node->addInput(a); - add_node->addInput(b); - add_node->addInput(cOne); - auto* add_output = add_node->output(); - add_output->setType(n->output()->type()); - state[add_output] = State::Nonzero; - n->output()->replaceAllUsesWith(add_output); - it.destroyCurrent(); + private: + bool isBackwardGraph() { + return std::any_of( + graph_->nodes().begin(), graph_->nodes().end(), [](Node* n) { + switch (n->kind()) { + case prim::AutogradAnyNonZero: + case prim::AutogradAdd: + case aten::_grad_sum_to_size: + return true; + default: + return false; + } + }); + } + + void replaceBlockInputsWithGraphInputs(Block* b) { + TORCH_INTERNAL_ASSERT(graph_->inputs().size() == b->inputs().size()); + size_t num_inputs = graph_->inputs().size(); + for (size_t i = 0; i < num_inputs; ++i) { + b->inputs().at(i)->replaceAllUsesWith(graph_->inputs().at(i)); + } + for (size_t i = 0; i < num_inputs; ++i) { + b->eraseInput(num_inputs - (1 + i)); + } + } + + void setStatesOnGraphInputs() { + for (Value* input : graph_->inputs()) { + const auto& tp = input->type(); + if (auto tt = tp->cast()) { + if (tt->undefined()) { + if (*tt->undefined()) { + state_[input] = State::Zero; + } else { + state_[input] = State::Nonzero; + } } else { - // otherwise we have conditionally-Nonzero things, and we need - // to actually run an AutogradAdd which will guard for Zeros - // so we leave the op as is - state[n->output()] = State::Unknown; + state_[input] = State::Unknown; } - } break; - case prim::AutogradZero: { - state[n->output()] = State::Zero; - } break; - case prim::profile: { - state[n->output()] = State::Unknown; - break; + } else if ( + tp->isSubtypeOf(TensorType::get()) || + tp->isSubtypeOf(ListType::ofTensors())) { + state_[input] = State::Nonzero; + } else { + state_[input] = State::Unknown; } - case prim::BailOut: { - if (auto ptt = n->output()->type()->expect()) { - state[n->output()] = ptt->undefined() - ? *ptt->undefined() ? State::Zero : State::Nonzero - : State::Unknown; - } - } break; - case prim::Guard: { - if (auto ptt = n->output()->type()->expect()) { - state[n->output()] = ptt->undefined() - ? *ptt->undefined() ? State::Zero : State::Nonzero - : State::Unknown; + } + } + + static Node* getUse(Value* inp, Symbol kind) { + for (auto use : inp->uses()) { + if (use.user->kind() == kind) { + return use.user; + } + } + + return nullptr; + } + + void removeProfiledOptionalUses(Value* v) { + std::vector profiled_opt_uses; + for (const Use& use : v->uses()) { + if (use.user->kind() == prim::profile_optional) { + profiled_opt_uses.push_back(use.user); + } + } + for (Node* n : profiled_opt_uses) { + n->output()->replaceAllUsesWith(v); + n->destroy(); + } + } + + Node* guardSpecializations() { + auto versioning_if = graph_->create(prim::If, {}, graph_->outputs().size()); + auto value_map = [](Value* v) { return v; }; + auto true_block = versioning_if->addBlock(); + auto false_block = versioning_if->addBlock(); + + // we will optimize true_block + true_block->cloneFrom(graph_->block(), value_map); + replaceBlockInputsWithGraphInputs(true_block); + false_block->cloneFrom(graph_->block(), value_map); + replaceBlockInputsWithGraphInputs(false_block); + + WithInsertPoint wip{graph_->block()}; + Value* none_val = graph_->insertConstant(IValue()); + std::vector checks; + + for (auto inp : graph_->inputs()) { + if (auto profile_optional_node = getUse(inp, prim::profile_optional)) { + if (profile_optional_node->i(attr::num_present) == 0 && + profile_optional_node->i(attr::num_none) != 0) { + auto check = graph_->insert(aten::__is__, {inp, none_val})->node(); + checks.push_back(check->output()); + profiled_none_.insert(inp); } - } break; - // Lowered GradOf block - case prim::If: { - auto if_input = n->input(0)->node(); - if (if_input->kind() == prim::AutogradAnyNonZero) { - auto all_zeros = std::all_of( - if_input->inputs().begin(), - if_input->inputs().end(), - [&](Value* v) { return state[v] == State::Zero; }); - - auto all_nonzeros = std::all_of( - if_input->inputs().begin(), - if_input->inputs().end(), - [&](Value* v) { return state[v] == State::Nonzero; }); - // Property 1: if all the gradInputs to the GradOf are Zero - // then the gradOutputs are also zero and will be represented as - // AutogradZero nodes - if (all_zeros) { - auto zero = g.createAutogradZero()->insertAfter(n)->output(); - state[zero] = State::Zero; - for (auto o : n->outputs()) { - o->replaceAllUsesWith(zero); - } + removeProfiledOptionalUses(inp); + continue; + } + + if (inp->uses().size() == 0 || !inp->type()->cast()) { + continue; + } + + // TODO: check multiple uses ? + auto pout = getUse(inp, prim::profile); + if (!pout) { + continue; + } + + auto pttp = pout->ty(attr::profiled_type)->expect(); + if (!pttp->undefined().has_value()) { + continue; + } + + state_[inp] = *pttp->undefined() ? State::Zero : State::Nonzero; + auto check = graph_->insert(prim::AutogradAnyNonZero, {inp}); + if (!*pttp->undefined()) { + check = graph_->insert(aten::__not__, {check}); + } + checks.push_back(check); + } + + // unable to specialize any of the inputs + if (checks.size() == 0) { + GRAPH_DUMP("Unable to add any specialization guards", graph_); + versioning_if->destroy(); + // the checks we inserted will be cleaned up + // by any subsequent DCE pass + return nullptr; + } + + Value* bool_list = + graph_->insertNode(graph_->createList(BoolType::get(), checks)) + ->output(); + Value* conjunction = graph_->insert(aten::all, {bool_list}); + + versioning_if->addInput(conjunction); + graph_->insertNode(versioning_if); + + auto ret = graph_->return_node(); + for (size_t i = 0; i < ret->inputs().size(); i++) { + auto ogo = ret->input(i); + auto ngo = versioning_if->output(i); + ngo->copyMetadata(ogo); + ret->replaceInput(i, ngo); + } + + // We've created: + // succesful_checks = Guards(...) + // if (succesful_checks) + // -> optimized graph + // else: + // -> fallback graph + // original graph + // + // Remove the dead original graph + for (auto it = graph_->block()->nodes().reverse().begin(); + *it != versioning_if;) { + Node* n = *it; + it++; + n->destroy(); + } + + GRAPH_DUMP("After guardSpecializations", graph_); + return versioning_if; + } + + void specializeAutogradOps(Block* block) { + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { + auto n = *it; + switch (n->kind()) { + case prim::AutogradAdd: { + auto a = n->input(0); + auto b = n->input(1); + // if one is Autograd zero, we can just drop the add + if (state_[a] == State::Zero) { + // Zero + b == b + n->output()->replaceAllUsesWith(b); + it.destroyCurrent(); + } else if (state_[b] == State::Zero) { + // a + Zero == a + n->output()->replaceAllUsesWith(a); + it.destroyCurrent(); + } else if ( + state_[a] == State::Nonzero && state_[b] == State::Nonzero) { + // when both are Nonzero, we can use a normal, optimizable add + // instruction + WithInsertPoint guard(n); + auto* cOne = graph_->insertConstant(1); + auto* add_node = graph_->insertNode(graph_->create(aten::add, 1)); + add_node->addInput(a); + add_node->addInput(b); + add_node->addInput(cOne); + auto* add_output = add_node->output(); + add_output->setType(n->output()->type()); + state_[add_output] = State::Nonzero; + n->output()->replaceAllUsesWith(add_output); it.destroyCurrent(); - break; + } else { + // otherwise we have conditionally-Nonzero things, and we need + // to actually run an AutogradAdd which will guard for Zeros + // so we leave the op as is + state_[n->output()] = State::Unknown; } + } break; + case prim::AutogradZero: { + state_[n->output()] = State::Zero; + } break; + case prim::profile: { + // this a profile node on a tensor use + // if we decided to specialize this graph + // its input may have undefinedness info + // otherwise it should be Unknown + if (n->inputs().size() > 0) { + state_[n->output()] = !state_.count(n->input()) + ? State::Unknown + : state_[n->output()] = state_[n->input()]; + } + break; + } + case prim::BailOut: { + if (auto ptt = n->output()->type()->expect()) { + state_[n->output()] = ptt->undefined() + ? *ptt->undefined() ? State::Zero : State::Nonzero + : State::Unknown; + } + } break; + // Lowered GradOf block + case prim::If: { + auto if_input = n->input(0)->node(); + if (if_input->kind() == prim::AutogradAnyNonZero) { + auto all_zeros = std::all_of( + if_input->inputs().begin(), + if_input->inputs().end(), + [&](Value* v) { return state_[v] == State::Zero; }); - if (all_nonzeros) { - auto body = n->blocks().at(0); - // hoist the nodes in the GradOf body to be before the linear block - for (auto it = body->nodes().begin(); it != body->nodes().end();) { - auto block_node = *it++; - block_node->moveBefore(n); + auto all_nonzeros = std::all_of( + if_input->inputs().begin(), + if_input->inputs().end(), + [&](Value* v) { return state_[v] == State::Nonzero; }); + // Property 1: if all the gradInputs to the GradOf are Zero + // then the gradOutputs are also zero and will be represented as + // AutogradZero nodes + if (all_zeros) { + auto zero = + graph_->createAutogradZero()->insertAfter(n)->output(); + state_[zero] = State::Zero; + for (auto o : n->outputs()) { + o->replaceAllUsesWith(zero); + } + it.destroyCurrent(); + break; } - for (size_t i = 0; i < n->outputs().size(); ++i) { - n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i)); - state[body->outputs().at(i)] = State::Nonzero; + specializeGradSumToSize(n->blocks().at(0)); + if (all_nonzeros) { + auto body = n->blocks().at(0); + // hoist the nodes in the GradOf body to be before the linear + // block + for (auto it = body->nodes().begin(); + it != body->nodes().end();) { + auto block_node = *it++; + block_node->moveBefore(n); + } + + for (size_t i = 0; i < n->outputs().size(); ++i) { + n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i)); + state_[body->outputs().at(i)] = State::Nonzero; + } + it.destroyCurrent(); + break; } - it.destroyCurrent(); - break; } - } - for (auto o : n->outputs()) { - state[o] = State::Unknown; + for (auto o : n->outputs()) { + state_[o] = State::Unknown; + } + break; } - break; + default: + for (auto o : n->outputs()) { + state_[o] = State::Unknown; + } + break; } - default: - for (auto o : n->outputs()) { - state[o] = State::Unknown; + } + } + + void specializeGradSumToSize(Block* b) { + for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) { + Node* n = *it; + if (n->kind() == aten::_grad_sum_to_size) { + if (n->input(1)->mustBeNone() || profiled_none_.count(n->input(1))) { + n->output()->replaceAllUsesWith(n->input(0)); + it.destroyCurrent(); } - break; + } } } + + std::shared_ptr graph_; + std::unordered_set profiled_none_; + std::unordered_map state_; +}; + +// propagate autograd zero information through a gradient graph and +// remove grad_of blocks if present. +// Note: this is a very limited pass. It only propagates autograd zeros for +// operations generated by the symbolic autodiff code and cleans up +// AutogradAdds when possible. Outputs of other nodes are conservatively +// marked Unknown and not optimized. +void specializeAutogradZero(std::shared_ptr g) { + AutogradZeroSpecializer azs(g); + azs.run(); } } // namespace jit diff --git a/torch/csrc/jit/passes/specialize_autogradzero.h b/torch/csrc/jit/passes/specialize_autogradzero.h index 41f2e5a047cbf83..02061f3454abe0b 100644 --- a/torch/csrc/jit/passes/specialize_autogradzero.h +++ b/torch/csrc/jit/passes/specialize_autogradzero.h @@ -11,7 +11,7 @@ namespace jit { // operations generated by the symbolic autodiff code and cleans up // AutogradAdds when possible. Outputs of other nodes are conservatively // marked Unknown and not optimized. -TORCH_API void specializeAutogradZero(Graph& g); +TORCH_API void specializeAutogradZero(std::shared_ptr g); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index e1dad93dd838525..dc1ef4940c9f8b9 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -14,8 +14,22 @@ namespace torch { namespace jit { +bool isSupportedForBlock(Node* node) { + switch (node->kind()) { + case aten::add: + case aten::mul: + return true; + default: + return false; + } +} + namespace tensorexpr { bool isSupported(Node* node) { + // For Block codegen we allow limited ops. + if (tensorexpr::getTEGenerateBlockCode()) { + return isSupportedForBlock(node); + } // TODO: switch (node->kind()) { case aten::add: @@ -67,6 +81,7 @@ bool isSupported(Node* node) { case aten::addcmul: case aten::neg: case aten::reciprocal: + case aten::sum: case aten::expm1: case aten::lgamma: case aten::unsqueeze: @@ -100,6 +115,7 @@ bool isSupported(Node* node) { return false; } } + } // namespace tensorexpr static bool texpr_fuser_enabled_ = false; @@ -618,6 +634,11 @@ class TensorExprFuser { void FuseTensorExprs(std::shared_ptr& graph, size_t min_group_size) { GRAPH_DUMP("Before TExprFuser: ", graph); + // Temporary change for Block code generation. + if (tensorexpr::getTEGenerateBlockCode()) { + min_group_size = 1; + } + // Get rid of dead code so that we don't waste effort fusing it. EliminateDeadCode(graph); diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp index 83ba6ea5166e901..3ebfab1d32646cb 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp +++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp @@ -143,6 +143,26 @@ void insertPrePackedConv2dOp(std::shared_ptr& graph) { rewriter.RegisterRewritePattern( conv_2d_pattern, prepacked_ops_conv2d_pattern); rewriter.runOnGraph(graph); + + std::string conv_2d_transpose_pattern = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], + %output_padding:int[], %groups:int): + %r = aten::conv_transpose2d(%input, %weight, %bias, %stride, %padding, %output_padding, %groups, %dilation) + return (%r) )"; + + std::string prepacked_ops_conv2d_transpose_pattern = R"( + graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %output_padding:int[], %groups:int): + %output_min_max : None = prim::Constant() + %packed_weight_bias = prepacked::conv2d_transpose_clamp_prepack( + %weight, %bias, %stride, %padding, %output_padding, %dilation, %groups, + %output_min_max, %output_min_max) + %r = prepacked::conv2d_transpose_clamp_run(%input, %packed_weight_bias) + return (%r) )"; + + SubgraphRewriter transpose_rewriter; + transpose_rewriter.RegisterRewritePattern( + conv_2d_transpose_pattern, prepacked_ops_conv2d_transpose_pattern); + transpose_rewriter.runOnGraph(graph); } void fuseHardtanhWithPackedOps(std::shared_ptr& graph) { @@ -321,7 +341,11 @@ void FoldPrePackingOps(script::Module& m) { return ( (n->kind() == Symbol::fromQualString("prepacked::linear_clamp_prepack")) || - n->kind() == Symbol::fromQualString("prepacked::conv2d_clamp_prepack")); + n->kind() == + Symbol::fromQualString("prepacked::conv2d_clamp_prepack") || + n->kind() == + Symbol::fromQualString( + "prepacked::conv2d_transpose_clamp_prepack")); }; PrePackingOpsFolder(m, filter_fn, "prepack_folding"); } @@ -397,7 +421,7 @@ script::Module optimizeForMobile( const std::set& blocklist, const std::vector& preserved_methods) { TORCH_INTERNAL_ASSERT( - "Mobile optimizaiton only available with XNNPACK at the moment. " + "Mobile optimization only available with XNNPACK at the moment. " "XNNPACK is not enabled. Please build with USE_XNNPACK=1"); return module; } diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a45190ff1ef91c5..5d2248df101e356 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -554,6 +554,18 @@ void initJITBindings(PyObject* module) { .def("_jit_texpr_fuser_enabled", &tensorExprFuserEnabled) .def("_jit_texpr_fallback_allowed", &tensorexpr::fallbackAllowed) .def("_jit_texpr_set_fallback_allowed", &tensorexpr::setFallbackAllowed) + .def( + "_jit_set_te_generate_block_code", + [](bool gen_block_code) { + using namespace torch::jit::tensorexpr; + return getTEGenerateBlockCode() = gen_block_code; + }) + .def( + "_jit_get_te_generate_block_code", + []() -> bool { + using namespace torch::jit::tensorexpr; + return getTEGenerateBlockCode(); + }) .def( "_jit_pass_fuse_tensorexprs", [](std::shared_ptr& g) { return FuseTensorExprs(g); }) diff --git a/torch/csrc/jit/python/python_arg_flatten.h b/torch/csrc/jit/python/python_arg_flatten.h index a67d76437ac9f7a..29af52f3b294302 100644 --- a/torch/csrc/jit/python/python_arg_flatten.h +++ b/torch/csrc/jit/python/python_arg_flatten.h @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include #include @@ -27,7 +27,7 @@ struct IODescriptor { } static size_t hash(const VariableMetadata& m) { - return get_hash(m.sizes, m.device, m.requires_grad, m.type); + return c10::get_hash(m.sizes, m.device, m.requires_grad, m.type); } std::vector sizes; @@ -42,7 +42,7 @@ struct IODescriptor { } static size_t hash(const IODescriptor& o) { - return get_hash(o.structure, o.metadata, o.grad_enabled); + return c10::get_hash(o.structure, o.metadata, o.grad_enabled); } void extend(const autograd::variable_list& list) { diff --git a/torch/csrc/jit/python/python_tree_views.cpp b/torch/csrc/jit/python/python_tree_views.cpp index 596531865ce3c61..1d599478418aae7 100644 --- a/torch/csrc/jit/python/python_tree_views.cpp +++ b/torch/csrc/jit/python/python_tree_views.cpp @@ -83,6 +83,12 @@ void initTreeViewBindings(PyObject* module) { self.highlight(stream); return stream.str(); }) + .def("__repr__", [](const SourceRange& self) { return self.str(); }) + .def( + "__str__", + [](const SourceRange& self) { + return "SourceRange at:\n" + self.str(); + }) .def_property_readonly("start", &SourceRange::start) .def_property_readonly("end", &SourceRange::end); py::class_(m, "SourceRangeFactory") diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index b844d03a00c7284..d47f4445b7cbe0f 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -45,8 +45,6 @@ #include #include -PYBIND11_MAKE_OPAQUE(torch::jit::ExtraFilesMap); - namespace torch { namespace jit { @@ -711,13 +709,24 @@ IValue pyIValueDeepcopy(const IValue& ivalue, const py::dict& memo) { return ivalue.deepcopy(ivalue_memo); } +ExtraFilesMap extra_files_from_python(const py::dict& pydict) { + ExtraFilesMap r; + for (const auto& it : pydict) { + r[py::cast(it.first)] = ""; + } + return r; +} + +void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) { + // py::dict is pointer-like type so it gets modified despite const& + for (const auto& it : m) { + pydict[py::str(it.first)] = py::bytes(it.second); + } +} + void initJitScriptBindings(PyObject* module) { auto m = py::handle(module).cast(); - // STL containers are not mutable by default and hence we need to bind as - // follows. - py::bind_map(m, "ExtraFilesMap"); - // NOLINTNEXTLINE(bugprone-unused-raii) py::class_>(m, "Capsule"); @@ -1363,22 +1372,25 @@ void initJitScriptBindings(PyObject* module) { [](std::shared_ptr cu, const std::string& filename, py::object map_location, - ExtraFilesMap& extra_files) { + const py::dict& extra_files) { c10::optional optional_device; if (!map_location.is(py::none())) { AT_ASSERT(THPDevice_Check(map_location.ptr())); optional_device = reinterpret_cast(map_location.ptr())->device; } - return import_ir_module( - std::move(cu), filename, optional_device, extra_files); + ExtraFilesMap extra_files_map = extra_files_from_python(extra_files); + auto ret = import_ir_module( + std::move(cu), filename, optional_device, extra_files_map); + extra_files_to_python(extra_files_map, extra_files); + return ret; }); m.def( "import_ir_module_from_buffer", [](std::shared_ptr cu, const std::string& buffer, py::object map_location, - ExtraFilesMap& extra_files) { + const py::dict& extra_files) { std::istringstream in(buffer); c10::optional optional_device; if (!map_location.is(py::none())) { @@ -1386,8 +1398,11 @@ void initJitScriptBindings(PyObject* module) { optional_device = reinterpret_cast(map_location.ptr())->device; } - return import_ir_module( - std::move(cu), in, optional_device, extra_files); + ExtraFilesMap extra_files_map = extra_files_from_python(extra_files); + auto ret = import_ir_module( + std::move(cu), in, optional_device, extra_files_map); + extra_files_to_python(extra_files_map, extra_files); + return ret; }); m.def( "_load_for_lite_interpreter", diff --git a/torch/csrc/jit/runtime/argument_spec.h b/torch/csrc/jit/runtime/argument_spec.h index 966e86afbfcaad0..401933c6d67e710 100644 --- a/torch/csrc/jit/runtime/argument_spec.h +++ b/torch/csrc/jit/runtime/argument_spec.h @@ -2,15 +2,13 @@ #include #include +#include #include #include #include -#include #include #include -#include - namespace torch { namespace jit { @@ -74,7 +72,8 @@ static_assert( struct ArgumentSpec { ArgumentSpec(size_t num_flat_tensor_inputs, size_t num_flat_optional_inputs) { - hash_code = hash_combine(num_flat_tensor_inputs, num_flat_optional_inputs); + hash_code = + c10::hash_combine(num_flat_tensor_inputs, num_flat_optional_inputs); tensor_args.reserve(num_flat_tensor_inputs); optional_presence.reserve(num_flat_optional_inputs); } @@ -82,7 +81,7 @@ struct ArgumentSpec { void addOptional(const IValue& input) { bool is_present = !input.isNone(); optional_presence.push_back(is_present); - hash_code = hash_combine(hash_code, is_present); + hash_code = c10::hash_combine(hash_code, is_present); } void addTensor(const IValue& input, bool with_grad) { @@ -111,7 +110,7 @@ struct ArgumentSpec { void combineHash(const ArgumentInfo& arg) { ArgumentInfo::plain_data_type arg_data; std::memcpy(&arg_data, &arg, sizeof(ArgumentInfo)); - hash_code = hash_combine(hash_code, arg_data); + hash_code = c10::hash_combine(hash_code, arg_data); } // equality is fast: check ninputs, and then check the raw array data, @@ -272,9 +271,9 @@ struct CompleteArgumentSpec { } // we precompute the hash_code to minimize the time inside of hash // table operations where we may need to hold a compiler cache lock. - hash_code = hash_combine(0, ninputs); + hash_code = c10::hash_combine(0, ninputs); for (auto d : data) { - hash_code = hash_combine(hash_code, d); + hash_code = c10::hash_combine(hash_code, d); } } @@ -447,7 +446,7 @@ namespace std { template struct hash> { size_t operator()(const c10::VaryingShape& vs) const { - return torch::get_hash( + return c10::get_hash( vs.size(), vs.size() ? vs.sizes().value() : std::vector>()); } @@ -456,7 +455,7 @@ struct hash> { template <> struct hash { size_t operator()(const c10::TensorType& ptt) const { - return torch::get_hash< + return c10::get_hash< c10::optional, c10::VaryingShape, c10::VaryingShape, diff --git a/torch/csrc/jit/runtime/autodiff.cpp b/torch/csrc/jit/runtime/autodiff.cpp index 71a0b70e60c9504..645491434f7d9f3 100644 --- a/torch/csrc/jit/runtime/autodiff.cpp +++ b/torch/csrc/jit/runtime/autodiff.cpp @@ -58,7 +58,8 @@ bool isDifferentiable(Node* n) { // Tensor", "aten::min(Tensor self) -> Tensor" if (n->kind() == prim::Constant || n->kind() == prim::AutogradZero || - n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk) + n->kind() == prim::AutogradAdd || n->kind() == prim::ConstantChunk || + n->kind() == prim::profile) return true; if (n->isMemberOf(differentiable_ops)) @@ -208,6 +209,8 @@ class GradientHelper { if (node->kind() == prim::AutogradAdd) { // NB: AutogradAdds don't broadcast return {grad_values.at(0), grad_values.at(0)}; + } else if (node->kind() == prim::profile) { + return {grad_values.at(0)}; } else if (node->kind() == prim::ConstantChunk) { auto* g = node->owningGraph(); diff --git a/torch/csrc/jit/runtime/custom_operator.h b/torch/csrc/jit/runtime/custom_operator.h index d1fe948e15ea002..45ad6676376ceb2 100644 --- a/torch/csrc/jit/runtime/custom_operator.h +++ b/torch/csrc/jit/runtime/custom_operator.h @@ -17,9 +17,13 @@ struct TORCH_API RegisterOperators { RegisterOperators() = default; /// Registers a vector of already created `Operator`s. - RegisterOperators(std::vector operators) { - for (Operator& o : operators) { - registerOperator(std::move(o)); + /// The operator element is now optional to filter null ops. It's backward + /// compatible and works for selective operator registration. + RegisterOperators(std::vector> operators) { + for (c10::optional& o : operators) { + if (o) { + registerOperator(std::move(o.value())); + } } } }; diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index a0e9c2d73847836..be3c5ca26d38aa4 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -53,6 +53,20 @@ namespace torch { namespace jit { +EnableProfilingGuard::EnableProfilingGuard() { + auto& profiling_mode = getProfilingMode(); + old_profiling_mode = profiling_mode; + profiling_mode = true; + auto& executor_mode = getExecutorMode(); + old_executor_mode = executor_mode; + executor_mode = true; +} + +EnableProfilingGuard::~EnableProfilingGuard() { + getProfilingMode() = old_profiling_mode; + getExecutorMode() = old_executor_mode; +} + namespace { c10::AliasAnalysisKind aliasAnalysisInternalSpecialCase() { return AliasAnalysisKind::INTERNAL_SPECIAL_CASE; @@ -588,7 +602,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { GRAPH_DUMP("After Inline, before LowerGradOf", opt_graph); LowerGradOf(*opt_graph); GRAPH_DUMP("After LowerGradOf, before specializeAutogradZero", opt_graph); - specializeAutogradZero(*opt_graph); + specializeAutogradZero(opt_graph); GRAPH_DUMP( "After specializeAutogradZero, before LowerSimpleTuples", opt_graph); LowerSimpleTuples(opt_graph); diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index 7dda3406fed274b..1f754067bb44339 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -43,6 +43,15 @@ struct GraphExecutorState { std::unordered_map execution_plans; }; +struct TORCH_API EnableProfilingGuard { + EnableProfilingGuard(); + ~EnableProfilingGuard(); + + private: + bool old_executor_mode = false; + bool old_profiling_mode = false; +}; + struct GraphExecutorImplBase; struct TORCH_API GraphExecutor { GraphExecutor() = default; diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 52c32c08521fca4..f693186dc4d7ba7 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -754,7 +754,14 @@ struct CodeImpl { void emitProfile(Node* node) { emitLoadInputs(node->inputs()); insertInstruction(PROFILE_OP, profile_function_table_.size()); - profile_function_table_.push_back(node->cast()->getCallback()); + if (node->cast()) { + profile_function_table_.push_back(node->cast()->getCallback()); + } else if (node->cast()) { + profile_function_table_.push_back( + node->cast()->getCallback()); + } else { + TORCH_INTERNAL_ASSERT(false); + } } void emitGetAttr(Node* node) { @@ -902,6 +909,7 @@ struct CodeImpl { case prim::BailOut: emitBailOut(node); break; + case prim::profile_optional: case prim::profile: emitProfile(node); break; diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 2d11b69d27b7fae..70c88eeab7b42ee 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -239,7 +239,9 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::MMBatchSide, // used as an optimization prim::Store, // used in interpreter only prim::profile, // used in interpreter only + prim::profile_optional, // used in interpreter only prim::TypeCheck, // used in interpreter only + prim::FallbackGraph, // converted into prim::CallFunction }; @@ -294,6 +296,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::GetAttr, prim::SetAttr, prim::profile, + prim::profile_optional, prim::TypeCheck, prim::Print, prim::CallFunction, @@ -305,6 +308,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::rpc_async, prim::Enter, prim::Exit, + prim::FallbackGraph, }; // Operators that should not be used by alias analysis diff --git a/torch/csrc/jit/runtime/operator.h b/torch/csrc/jit/runtime/operator.h index 6acc2aee7d7fae3..05305c71f27c547 100644 --- a/torch/csrc/jit/runtime/operator.h +++ b/torch/csrc/jit/runtime/operator.h @@ -5,10 +5,12 @@ #include #include +#include #include #include #include #include +#include #include #include @@ -223,5 +225,26 @@ TORCH_API void ensure_c10_registerer_defined(); // Used to assert that unschematized operators have an analysis method written TORCH_API bool aliasAnalysisHasSpecialCaseFor(c10::Symbol sym); +// A factory function to generate an optional operator. It has two +// instantiations depending on the template bool arg value. The arg can be a +// compile-time function for the selective op registration based on schema +// string. +template +c10::optional OperatorGenerator( + torch::detail::SelectiveStr schema_str, + Func&& op, + AliasAnalysisKind alias_analysis) { + return c10::optional(Operator( + std::string(schema_str), std::forward(op), alias_analysis)); +} + +template +c10::optional OperatorGenerator( + torch::detail::SelectiveStr schema_str, + Func&& op, + AliasAnalysisKind alias_analysis) { + return c10::nullopt; +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 28e277c6d8a3d09..f52e05848f57190 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -73,6 +73,12 @@ static bool needsGradientInProfilingMode(Block* b) { return true; } } + if (n->kind() == prim::profile) { + auto type = n->ty(attr::profiled_type)->expect(); + if (type->requiresGrad() && *type->requiresGrad()) { + return true; + } + } for (auto ib : n->blocks()) { if (needsGradientInProfilingMode(ib)) { @@ -114,7 +120,7 @@ void runPreAutodiffPassPipeline(std::shared_ptr& graph) { GRAPH_DUMP("After InsertBailOuts, before specializeAutogradZero", graph); } - specializeAutogradZero(*graph); + specializeAutogradZero(graph); GRAPH_DUMP("After specializeAutogradZero", graph); // runRequiredPasses { @@ -296,6 +302,8 @@ void ProfilingGraphExecutorImpl::runProfilingOptimizations( GRAPH_DUMP("Forward graph:", gradient.f); GRAPH_DUMP("Backward graph:", gradient.df); runDiffGraphPasses(gradient.f); + // replaces fallback graphs inserted by TE Fuser + replaceFallbackGraphWithFallbackFunction(gradient.f->block()); packGradient(gradient, dnode); GRAPH_DEBUG("Finished optimizing diff node ", idx++); } @@ -376,12 +384,21 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( std::lock_guard lock(compile_mutex); GRAPH_DEBUG("Running ProfilingGraphExecutorImpl ", this); + // if tensorExprFuserEnabled() returns true we need to persist the very first + // time ProfilingGraphExecutorImpl is called, so we can update it correctly + // for fallback functions in ProfilingGraphExecutorImpl Else, + // getPlanFor(remaining_bailout_depth) is corrected and persisted by the Code + // object in interpreter. + if (!remaining_bailout_depth_.has_value() || !tensorExprFuserEnabled()) { + remaining_bailout_depth_ = remaining_bailout_depth; + } + if (optimized_plan_) { return *optimized_plan_; } // simple executor - if (remaining_bailout_depth == 0) { + if (*remaining_bailout_depth_ == 0) { auto copy = graph->copy(); runProfilingInsensitiveOptimizations(copy); GRAPH_DUMP("Optimized SimpleExecutor Graph : ", copy); @@ -393,13 +410,12 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( if (!pr_) { auto copy = graph->copy(); runProfilingInsensitiveOptimizations(copy); - if (remaining_bailout_depth == getBailoutDepth()) { + if (*remaining_bailout_depth_ == getBailoutDepth()) { PeelProfilingLoops(copy); } pr_ = ProfilingRecord::instrumentGraph(copy); - auto pr_copy = pr_->graph()->copy(); - GRAPH_DUMP("Profiled Graph: ", pr_copy); - profiling_plan_ = ExecutionPlan(pr_copy, function_name_); + GRAPH_DUMP("Profiled Graph: ", pr_->graph()); + profiling_plan_ = ExecutionPlan(pr_->graph(), function_name_); // fall-through } @@ -411,9 +427,13 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor( auto copy = pr_->graph()->copy(); ProfilingRecord::removeProfileCounter(copy->block()); runProfilingOptimizations(copy); + // replaces a fallback graph inserted by + // specialize_autogradzero if one exists + replaceFallbackGraphWithFallbackFunction(copy->block()); + GRAPH_DUMP("Optimized Graph: ", copy); // cache optimized_plan_ = - ExecutionPlan(copy, function_name_, remaining_bailout_depth); + ExecutionPlan(copy, function_name_, *remaining_bailout_depth_); return *optimized_plan_; } @@ -425,5 +445,96 @@ GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() { return state; } +void replaceBlockWithFallbackGraph(Block* b) { + auto graph = std::make_shared(); + auto value_map = [](Value* v) { return v; }; + graph->block()->cloneFrom(b, value_map); + auto fallback = b->owningGraph()->create( + prim::FallbackGraph, b->inputs(), b->outputs().size()); + fallback->g_(attr::Subgraph, graph); + b->prependNode(fallback); + + for (size_t i = 0; i < b->outputs().size(); i++) { + fallback->output(i)->setType(b->outputs()[i]->type()); + fallback->output(i)->copyMetadata(b->outputs()[i]); + b->replaceOutput(i, fallback->output(i)); + } + + for (auto it = b->nodes().rbegin(); it != fallback->iterator(); it++) { + it.destroyCurrent(); + } +} + +static Function* createFallbackPathFunction( + Block* b, + const std::string& function_name) { + auto value_map = [](Value* v) { return v; }; + auto graph = std::make_shared(); + graph->block()->cloneFrom(b, value_map); + + auto otypes = c10::fmap( + graph->return_node()->inputs(), [](Value* v) { return v->type(); }); + // a GraphFunction call only have one output, so all the outputs + // need to be packed into a tuple + auto tuple_type = TupleType::create(otypes); + auto return_tuple = graph->createTuple(graph->return_node()->inputs()); + graph->appendNode(return_tuple); + for (int i = static_cast(graph->outputs().size()) - 1; i >= 0; i--) { + graph->eraseOutput(i); + } + graph->registerOutput(return_tuple->output()); + return new GraphFunction(function_name, graph, nullptr); +} + +Node* insertFallbackFunctionCall( + Graph* graph, + Function* func, + ArrayRef inputs) { + auto tuple_type = func->graph()->return_node()->input(0)->type(); + Value* fn_constant = graph->insertNode(graph->create(prim::Constant)) + ->s_(attr::name, func->name()) + ->i_(Symbol::attr("fallback"), 1) + ->output() + ->setType(FunctionType::create(func)); + std::vector func_call_inputs = {fn_constant}; + func_call_inputs.insert(func_call_inputs.end(), inputs.begin(), inputs.end()); + Value* result = + graph->insertNode(graph->create(prim::CallFunction, func_call_inputs)) + ->output() + ->setType(tuple_type); + + auto fun_unpack_tuple = graph->insertNode(graph->createTupleUnpack(result)); + return fun_unpack_tuple; +} + +void ProfilingGraphExecutorImpl::replaceFallbackGraphWithFallbackFunction( + Block* b) { + Stack s; + for (auto it = b->nodes().begin(); it != b->nodes().end();) { + if (it->kind() == prim::FallbackGraph) { + auto fallback_func = createFallbackPathFunction( + it->g(attr::Subgraph)->block(), "fallback_function"); + TORCH_INTERNAL_ASSERT(*remaining_bailout_depth_ > 0); + GRAPH_DEBUG( + "getPlanFor for", getHeader(*it), " ", *remaining_bailout_depth_); + fallback_func->get_executor().getPlanFor( + s, *remaining_bailout_depth_ - 1); + fallback_functions_.emplace_back(fallback_func); + WithInsertPoint wip{*it}; + auto function_call = insertFallbackFunctionCall( + b->owningGraph(), fallback_func, it->inputs()); + for (size_t i = 0; i < function_call->outputs().size(); i++) { + it->output(i)->replaceAllUsesWith(function_call->output(i)); + } + it.destroyCurrent(); + } else { + for (Block* ib : it->blocks()) { + replaceFallbackGraphWithFallbackFunction(ib); + } + it++; + } + } +} + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h index e1e734942c12cc2..f2cf9cd4cd1e9ca 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h @@ -4,6 +4,13 @@ namespace torch { namespace jit { +TORCH_API Node* createFallbackGraph( + Block* b, + ArrayRef inputs, + Graph* g); + +TORCH_API void replaceBlockWithFallbackGraph(Block* b); + struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { ProfilingGraphExecutorImpl( const std::shared_ptr& graph, @@ -17,10 +24,21 @@ struct ProfilingGraphExecutorImpl : public GraphExecutorImplBase { private: void runProfilingInsensitiveOptimizations(std::shared_ptr& graph); void runProfilingOptimizations(std::shared_ptr& graph); + void replaceFallbackGraphWithFallbackFunction(Block* b); std::unique_ptr pr_; c10::optional profiling_plan_; // plan to run in order to profiling the code c10::optional optimized_plan_; + // fallback functions are inserted for tensorexpr fusion groups + // and by specialize_autogradzero. Whenever, at runtime, input + // tensor don't match profiled properties, fallback functions are called + // They are the deoptimized version of the logic in fusion groups + // and/or autograd. + // The fallback functions are owned by a GraphExecutor instance + // They only exist in the optimized graph which is a private property + // of the GraphExecutor and only shared with InterpreterState + std::vector> fallback_functions_; + c10::optional remaining_bailout_depth_; }; } // namespace jit diff --git a/torch/csrc/jit/runtime/profiling_record.cpp b/torch/csrc/jit/runtime/profiling_record.cpp index b9c9be2c54c9126..1ec50c9204ceefe 100644 --- a/torch/csrc/jit/runtime/profiling_record.cpp +++ b/torch/csrc/jit/runtime/profiling_record.cpp @@ -51,6 +51,19 @@ ProfileOp* ProfilingRecord::createProfileNode( return pn; } +ProfileOptionalOp* ProfilingRecord::createProfileOptionalNode( + const std::function& fp, + at::ArrayRef inputs) { + auto pn = new ProfileOptionalOp(profiled_graph_.get(), fp); + pn->i_(attr::num_present, 0); + pn->i_(attr::num_none, 0); + + for (auto in : inputs) { + pn->addInput(in); + } + return pn; +} + static void unprofileGraphInputs(const std::shared_ptr& graph) { for (auto i : graph->inputs()) { if (i->type()->isSubtypeOf(TensorType::get())) { @@ -205,6 +218,12 @@ void ProfilingRecord::removeProfileCounter(Block* b) { } } +bool hasGradSumToSizeUses(Value* v) { + return std::any_of(v->uses().begin(), v->uses().end(), [](const Use& use) { + return use.user->kind() == aten::_grad_sum_to_size; + }); +} + void ProfilingRecord::instrumentBlock(Block* block) { for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { auto n = *it; @@ -214,6 +233,33 @@ void ProfilingRecord::instrumentBlock(Block* block) { (needsProfiledInputs(n) || needsProfiledOutput(i->node()))) { insertShapeProfile(n, offset); } + + if (i->type()->cast() && hasGradSumToSizeUses(i)) { + // here we are profile the definition instead of the use, + // because we are only optimizing in the case of a None value which is + // immutable + auto opt_pn = createProfileOptionalNode(nullptr, {i}); + std::function optional_profiler = [this, + opt_pn](Stack& stack) { + std::lock_guard lock(this->mutex_); + // frame_id is unused + int64_t frame_id = 0; + pop(stack, frame_id); + IValue value; + pop(stack, value); + if (value.isNone()) { + opt_pn->i_(attr::num_none, opt_pn->i(attr::num_none) + 1); + } else { + opt_pn->i_(attr::num_present, opt_pn->i(attr::num_present) + 1); + } + push(stack, value); + }; + opt_pn->setCallback(optional_profiler); + auto pno = opt_pn->addOutput(); + pno->setType(i->type()); + opt_pn->insertAfter(i->node()); + i->replaceAllUsesAfterNodeWith(opt_pn, pno); + } } for (auto b : n->blocks()) { diff --git a/torch/csrc/jit/runtime/profiling_record.h b/torch/csrc/jit/runtime/profiling_record.h index 7cadbfcfa920f1d..aa945c29c25272e 100644 --- a/torch/csrc/jit/runtime/profiling_record.h +++ b/torch/csrc/jit/runtime/profiling_record.h @@ -206,6 +206,9 @@ struct ProfilingRecord { ProfileOp* createProfileNode( const std::function& fp, at::ArrayRef inputs); + ProfileOptionalOp* createProfileOptionalNode( + const std::function& fp, + at::ArrayRef inputs); void instrumentBlock(Block* block); void insertShapeProfile(Node* n, size_t offset); ProfilingRecord(std::shared_ptr g); diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp index 618485141544cd3..4e1a4fb1f211361 100644 --- a/torch/csrc/jit/runtime/register_c10_ops.cpp +++ b/torch/csrc/jit/runtime/register_c10_ops.cpp @@ -122,14 +122,7 @@ Operator createOperatorFromC10_withTracingHandledHere( jit::tracer::setTracingState(nullptr); } -#ifdef USE_STATIC_DISPATCH - { - at::AutoNonVariableTypeMode non_var_type_mode(true); - op.callBoxed(stack); - } -#else op.callBoxed(stack); -#endif // USE_STATIC_DISPATCH if (tracer_state) { jit::tracer::setTracingState(std::move(tracer_state)); diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index 4d7e79647e25e36..ae974c063ef358a 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -416,36 +416,38 @@ void listCopyAndSort(Stack* stack); void listSetItem(Stack* stack); -#define DEFINE_GENERIC_BINARY_OP(aten_op, op, result) \ - Operator( \ - #aten_op ".int_int(int a, int b) -> " #result, \ - [](Stack* stack) { \ - int64_t a, b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ - aliasAnalysisFromSchema()), \ - Operator( \ - #aten_op ".float_float(float a, float b) -> " #result, \ - [](Stack* stack) { \ - double a, b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ +#define DEFINE_GENERIC_BINARY_OP(aten_op, op, result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".int_int(int a, int b) -> " #result), \ + [](Stack* stack) { \ + int64_t a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + #aten_op ".float_float(float a, float b) -> " #result), \ + [](Stack* stack) { \ + double a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) // define implementations for primitive number ops #define DEFINE_GENERIC_OP(aten_op, int_op, float_op, int_result, float_result) \ - Operator( \ - #aten_op ".int(int a, int b) -> " #int_result, \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> " #int_result), \ [](Stack* stack) { \ int64_t a, b; \ pop(stack, a, b); \ push(stack, int_op); \ }, \ aliasAnalysisFromSchema()), \ - Operator( \ - #aten_op ".float(float a, float b) -> " #float_result, \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + #aten_op ".float(float a, float b) -> " #float_result), \ [](Stack* stack) { \ double a, b; \ pop(stack, a, b); \ @@ -453,83 +455,86 @@ void listSetItem(Stack* stack); }, \ aliasAnalysisFromSchema()) -#define DEFINE_INT_FLOAT_OP(aten_op, op, result) \ - Operator( \ - #aten_op ".int_float(int a, float b) -> " #result, \ - [](Stack* stack) { \ - int64_t a; \ - double b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ - aliasAnalysisFromSchema()), \ - Operator( \ - #aten_op ".float_int(float a, int b) -> " #result, \ - [](Stack* stack) { \ - double a; \ - int64_t b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ +#define DEFINE_INT_FLOAT_OP(aten_op, op, result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op \ + ".int_float(int a, float b) -> " #result), \ + [](Stack* stack) { \ + int64_t a; \ + double b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op \ + ".float_int(float a, int b) -> " #result), \ + [](Stack* stack) { \ + double a; \ + int64_t b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_INT_OP(aten_op, op) \ - Operator( \ - #aten_op ".int(int a, int b) -> int", \ - [](Stack* stack) { \ - int64_t a, b; \ - pop(stack, a, b); \ - push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \ - }, \ +#define DEFINE_INT_OP(aten_op, op) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a, int b) -> int"), \ + [](Stack* stack) { \ + int64_t a, b; \ + pop(stack, a, b); \ + push(stack, op); /* NOLINT(hicpp-signed-bitwise) */ \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_STR_CMP_OP(aten_op, op) \ - Operator( \ - #aten_op ".str(str a, str b) -> bool", \ - [](Stack* stack) { \ - auto b = pop(stack).toStringRef(); \ - auto a = pop(stack).toStringRef(); \ - push(stack, op); \ - }, \ +#define DEFINE_STR_CMP_OP(aten_op, op) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".str(str a, str b) -> bool"), \ + [](Stack* stack) { \ + auto b = pop(stack).toStringRef(); \ + auto a = pop(stack).toStringRef(); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) // define a primitive op over Scalar operands. // it's necessary to register this overload following // int/float variations to avoid trapping Scalar args // in unintended implicit conversions -#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \ - Operator( \ - #aten_op "(Scalar a, Scalar b) -> " #result, \ - [](Stack* stack) { \ - IValue x, y; \ - pop(stack, x, y); \ - if (x.isDouble()) { \ - if (y.isDouble()) { \ - double a = x.toDouble(); \ - double b = y.toDouble(); \ - push(stack, float_op); \ - } else { \ - double a = x.toDouble(); \ - int64_t b = y.toInt(); \ - push(stack, float_op); \ - } \ - } else { \ - if (y.isDouble()) { \ - int64_t a = x.toInt(); \ - double b = y.toDouble(); \ - push(stack, float_op); \ - } else { \ - int64_t a = x.toInt(); \ - int64_t b = y.toInt(); \ - push(stack, int_op); \ - } \ - } \ - }, \ +#define DEFINE_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op "(Scalar a, Scalar b) -> " #result), \ + [](Stack* stack) { \ + IValue x, y; \ + pop(stack, x, y); \ + if (x.isDouble()) { \ + if (y.isDouble()) { \ + double a = x.toDouble(); \ + double b = y.toDouble(); \ + push(stack, float_op); \ + } else { \ + double a = x.toDouble(); \ + int64_t b = y.toInt(); \ + push(stack, float_op); \ + } \ + } else { \ + if (y.isDouble()) { \ + int64_t a = x.toInt(); \ + double b = y.toDouble(); \ + push(stack, float_op); \ + } else { \ + int64_t a = x.toInt(); \ + int64_t b = y.toInt(); \ + push(stack, int_op); \ + } \ + } \ + }, \ aliasAnalysisFromSchema()) #define DEFINE_SCALAR_SCALAR_BINARY_OP(aten_op, int_op, float_op, result) \ - Operator( \ - #aten_op ".Scalar_Scalar(Scalar a, Scalar b) -> " #result, \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + #aten_op ".Scalar_Scalar(Scalar a, Scalar b) -> " #result), \ [](Stack* stack) { \ IValue x, y; \ pop(stack, x, y); \ @@ -573,60 +578,60 @@ void listSetItem(Stack* stack); DEFINE_SCALAR_BINARY_OP(aten_op, op, op, bool), \ DEFINE_STR_CMP_OP(aten_op, op) -#define DEFINE_UNARY_INT_OP(aten_op, op, result) \ - Operator( \ - #aten_op ".int(int a) -> " #result, \ - [](Stack* stack) { \ - int64_t a; \ - pop(stack, a); \ - push(stack, op); \ - }, \ +#define DEFINE_UNARY_INT_OP(aten_op, op, result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".int(int a) -> " #result), \ + [](Stack* stack) { \ + int64_t a; \ + pop(stack, a); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_UNARY_FLOAT_OP(aten_op, op, result) \ - Operator( \ - #aten_op ".float(float a) -> " #result, \ - [](Stack* stack) { \ - double a; \ - pop(stack, a); \ - push(stack, op); \ - }, \ +#define DEFINE_UNARY_FLOAT_OP(aten_op, op, result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".float(float a) -> " #result), \ + [](Stack* stack) { \ + double a; \ + pop(stack, a); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \ - DEFINE_UNARY_INT_OP(aten_op, op, int_result), \ - DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result), \ - Operator( \ - #aten_op ".Scalar(Scalar a) -> Scalar", \ - [](Stack* stack) { \ - IValue x; \ - pop(stack, x); \ - if (x.isDouble()) { \ - double a = x.toDouble(); \ - push(stack, static_cast(op)); \ - } else { \ - int64_t a = x.toInt(); \ - push(stack, static_cast(op)); \ - } \ - }, \ +#define DEFINE_UNARY_OP(aten_op, op, int_result, float_result) \ + DEFINE_UNARY_INT_OP(aten_op, op, int_result), \ + DEFINE_UNARY_FLOAT_OP(aten_op, op, float_result), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".Scalar(Scalar a) -> Scalar"), \ + [](Stack* stack) { \ + IValue x; \ + pop(stack, x); \ + if (x.isDouble()) { \ + double a = x.toDouble(); \ + push(stack, static_cast(op)); \ + } else { \ + int64_t a = x.toInt(); \ + push(stack, static_cast(op)); \ + } \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_BOOL_OP(aten_op, op) \ - Operator( \ - #aten_op ".bool(bool a, bool b) -> bool", \ - [](Stack* stack) { \ - bool a, b; \ - pop(stack, a, b); \ - push(stack, op); \ - }, \ +#define DEFINE_BOOL_OP(aten_op, op) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#aten_op ".bool(bool a, bool b) -> bool"), \ + [](Stack* stack) { \ + bool a, b; \ + pop(stack, a, b); \ + push(stack, op); \ + }, \ aliasAnalysisFromSchema()) -#define DEFINE_STRING_OP(op_name, string_op, result) \ - Operator( \ - #op_name ".str(str a, str b) ->" #result, \ - [](Stack* stack) { \ - auto b = pop(stack).toStringRef(); \ - auto a = pop(stack).toStringRef(); \ - push(stack, string_op); \ - }, \ +#define DEFINE_STRING_OP(op_name, string_op, result) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA(#op_name ".str(str a, str b) ->" #result), \ + [](Stack* stack) { \ + auto b = pop(stack).toStringRef(); \ + auto a = pop(stack).toStringRef(); \ + push(stack, string_op); \ + }, \ aliasAnalysisFromSchema()) } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index a67de455619d9b1..b97ead8afa6bcb0 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -128,16 +129,16 @@ TORCH_LIBRARY_IMPL(aten, CatchAll, m) { } RegisterOperators reg( - {Operator( - "aten::str(t elem) -> str", + {OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::str(t elem) -> str"), [](Stack* stack) { std::stringstream ss; ss << pop(stack); push(stack, ss.str()); }, aliasAnalysisFromSchema()), - Operator( - "aten::list(str t) -> str[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::list(str t) -> str[]"), [](Stack& stack) { auto str = pop(stack).toStringRef(); c10::List chars; @@ -150,8 +151,9 @@ RegisterOperators reg( }, aliasAnalysisFromSchema()), // only used internally in range() translation - Operator( - "aten::__range_length(int lo, int hi, int step) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__range_length(int lo, int hi, int step) -> int"), [](Stack& stack) { int64_t lo, hi, step; pop(stack, lo, hi, step); @@ -169,8 +171,9 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), - Operator( - "aten::__derive_index(int index, int start, int step) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__derive_index(int index, int start, int step) -> int"), [](Stack& stack) { int64_t index, start, step; pop(stack, index, start, step); @@ -178,16 +181,16 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), - Operator( - "prim::TupleUnpack(Any tup) -> ...", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::TupleUnpack(Any tup) -> ..."), [](Stack* stack) { tupleUnpack(*stack); }, aliasAnalysisSpecialCase()), - Operator( - "prim::unchecked_cast(t x) -> t", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::unchecked_cast(t x) -> t"), noop, aliasAnalysisSpecialCase()), - Operator( - "aten::IntImplicit(Tensor a) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::IntImplicit(Tensor a) -> int"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -195,8 +198,8 @@ RegisterOperators reg( push(stack, a.item()); }, aliasAnalysisFromSchema()), - Operator( - "aten::FloatImplicit(Tensor a) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::FloatImplicit(Tensor a) -> float"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -204,8 +207,8 @@ RegisterOperators reg( push(stack, a.item()); }, aliasAnalysisFromSchema()), - Operator( - "aten::ScalarImplicit(Tensor a) -> Scalar", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::ScalarImplicit(Tensor a) -> Scalar"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -213,40 +216,40 @@ RegisterOperators reg( push(stack, a.item()); }, aliasAnalysisFromSchema()), - Operator( - "aten::Bool.Tensor(Tensor a) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Bool.Tensor(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); push(stack, a.is_nonzero()); }, aliasAnalysisFromSchema()), - Operator( - "aten::Bool.int(int a) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Bool.int(int a) -> bool"), [](Stack* stack) { int64_t i; pop(stack, i); push(stack, (bool)i); }, aliasAnalysisFromSchema()), - Operator( - "aten::Bool.float(float a) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Bool.float(float a) -> bool"), [](Stack* stack) { double d; pop(stack, d); push(stack, (bool)d); }, aliasAnalysisFromSchema()), - Operator( - "aten::Float.Tensor(Tensor a) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Float.Tensor(Tensor a) -> float"), [](Stack* stack) { at::Tensor a; pop(stack, a); push(stack, a.item()); }, aliasAnalysisFromSchema()), - Operator( - "aten::Float.Scalar(Scalar a) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Float.Scalar(Scalar a) -> float"), [](Stack* stack) { IValue scalar; pop(stack, scalar); @@ -257,24 +260,24 @@ RegisterOperators reg( } }, aliasAnalysisFromSchema()), - Operator( - "aten::Float.int(int a) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Float.int(int a) -> float"), [](Stack* stack) { int64_t i; pop(stack, i); push(stack, (float)i); }, aliasAnalysisFromSchema()), - Operator( - "aten::Float.bool(bool a) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Float.bool(bool a) -> float"), [](Stack* stack) { bool b; pop(stack, b); push(stack, (float)b); }, aliasAnalysisFromSchema()), - Operator( - "aten::Float.str(str a) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Float.str(str a) -> float"), [](Stack* stack) { auto s = pop(stack).toString(); std::string::size_type sz; @@ -289,68 +292,69 @@ RegisterOperators reg( } }, aliasAnalysisFromSchema()), - Operator( - "aten::format(str self, ...) -> str", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::format(str self, ...) -> str"), [](Stack* stack) { size_t num_inputs = pop(stack).toInt(); format(*stack, num_inputs); }, aliasAnalysisFromSchema()), - Operator( - "prim::NumToTensor.Scalar(Scalar a) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::NumToTensor.Scalar(Scalar a) -> Tensor"), [](Stack* stack) { at::Scalar s; pop(stack, s); push(stack, at::scalar_to_tensor(s)); }, aliasAnalysisFromSchema()), - Operator( - "prim::RaiseException(str msg) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::RaiseException(str msg) -> ()"), [](Stack* stack) { throw JITException(pop(stack).toStringRef()); }, aliasAnalysisFromSchema()), - Operator( - "aten::Size(int[] sizes) -> int[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Size(int[] sizes) -> int[]"), [](Stack* stack) {}, aliasAnalysisFromSchema()), - Operator( - "aten::size(Tensor self) -> int[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::size(Tensor self) -> int[]"), [](Stack* stack) { auto t = std::move(pop(stack)).toTensor(); pack(stack, t.sizes().vec()); }, aliasAnalysisFromSchema()), - Operator( - "prim::EnumName(AnyEnumType enum) -> str", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::EnumName(AnyEnumType enum) -> str"), [](Stack* stack) { IValue e = pop(stack); push(stack, e.toEnumHolder()->name()); }, aliasAnalysisFromSchema()), - Operator( - "prim::EnumValue.int(AnyEnumType enum) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::EnumValue.int(AnyEnumType enum) -> int"), [](Stack* stack) { IValue e = pop(stack); push(stack, e.toEnumHolder()->value()); }, aliasAnalysisFromSchema()), - Operator( - "prim::EnumValue.float(AnyEnumType enum) -> float", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "prim::EnumValue.float(AnyEnumType enum) -> float"), [](Stack* stack) { IValue e = pop(stack); push(stack, e.toEnumHolder()->value()); }, aliasAnalysisFromSchema()), - Operator( - "prim::EnumValue.str(AnyEnumType enum) -> str", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::EnumValue.str(AnyEnumType enum) -> str"), [](Stack* stack) { IValue e = pop(stack); push(stack, e.toEnumHolder()->value()); }, aliasAnalysisFromSchema()), - Operator( + OperatorGenerator( // note the compiler knows to type TupleIndex more accurately than it // is listed here. - "prim::TupleIndex(Any tup, int i) -> Any", + TORCH_SELECTIVE_SCHEMA("prim::TupleIndex(Any tup, int i) -> Any"), [](Stack* stack) { int64_t index = pop(stack).toInt(); auto tuple = pop(stack).toTuple(); @@ -362,69 +366,70 @@ RegisterOperators reg( stack->emplace_back(tuple->elements()[norm_index]); }, aliasAnalysisSpecialCase()), - Operator( - "aten::ne.int_list(int[] a, int[] b) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::ne.int_list(int[] a, int[] b) -> bool"), listNe, aliasAnalysisFromSchema()), - Operator( - "prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)"), noop, aliasAnalysisFromSchema()), - Operator( - "prim::device(Tensor a) -> Device", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::device(Tensor a) -> Device"), [](Stack* stack) { push(stack, pop(stack).toTensor().device()); }, aliasAnalysisFromSchema()), - Operator( - "prim::dtype(Tensor a) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::dtype(Tensor a) -> int"), [](Stack* stack) { at::Tensor a; pop(stack, a); push(stack, static_cast(a.scalar_type())); }, aliasAnalysisFromSchema()), - Operator( - "aten::__not__(bool self) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::__not__(bool self) -> bool"), [](Stack* stack) { push(stack, !pop(stack).toBool()); }, aliasAnalysisFromSchema()), - Operator( - "aten::__is__(t1 self, t2 obj) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::__is__(t1 self, t2 obj) -> bool"), [](Stack* stack) { IValue self, obj; pop(stack, self, obj); push(stack, self.is(obj)); }, aliasAnalysisFromSchema()), - Operator( - "aten::__isnot__(t1 self, t2 obj) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::__isnot__(t1 self, t2 obj) -> bool"), [](Stack* stack) { IValue self, obj; pop(stack, self, obj); push(stack, !self.is(obj)); }, aliasAnalysisFromSchema()), - Operator( - "aten::element_size(Tensor self) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::element_size(Tensor self) -> int"), [](Stack* stack) { at::Tensor arg = pop(stack).toTensor(); push(stack, arg.element_size()); }, aliasAnalysisFromSchema()), - Operator( - "aten::numel(Tensor self) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::numel(Tensor self) -> int"), [](Stack* stack) { at::Tensor arg = pop(stack).toTensor(); push(stack, arg.numel()); }, aliasAnalysisFromSchema()), - Operator( - "aten::dim(Tensor self) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::dim(Tensor self) -> int"), [](Stack* stack) { at::Tensor arg = pop(stack).toTensor(); push(stack, arg.dim()); }, aliasAnalysisFromSchema()), - Operator( - "aten::get_device(Tensor self) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"), [](Stack* stack) { RECORD_FUNCTION("get_device", std::vector()); auto result = @@ -433,8 +438,8 @@ RegisterOperators reg( pack(stack, result); }, aliasAnalysisFromSchema()), - Operator( - "aten::storage_offset(Tensor self) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::storage_offset(Tensor self) -> int"), [](Stack* stack) { RECORD_FUNCTION("storage_offset", std::vector()); auto result = @@ -443,8 +448,8 @@ RegisterOperators reg( pack(stack, result); }, aliasAnalysisFromSchema()), - Operator( - "aten::is_contiguous(Tensor self) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::is_contiguous(Tensor self) -> bool"), [](Stack* stack) { RECORD_FUNCTION("is_contiguous", std::vector()); auto result = @@ -455,89 +460,99 @@ RegisterOperators reg( aliasAnalysisFromSchema()), // these ops are generic over the list element type. // CREATING GENERIC_LIST_OPS - Operator( - "aten::select.t(t[](a) list, int idx) -> t(*)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::select.t(t[](a) list, int idx) -> t(*)"), listSelect, aliasAnalysisFromSchema()), - Operator( - "aten::__getitem__.t(t[](a) list, int idx) -> t(*)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__getitem__.t(t[](a) list, int idx) -> t(*)"), listSelect, aliasAnalysisFromSchema()), - Operator( - "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::append.t(t[](a!) self, t(c -> *) el) -> t[](a!)"), listAppend, aliasAnalysisFromSchema()), - Operator( - "aten::reverse.t(t[](a!) self) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::reverse.t(t[](a!) self) -> ()"), listReverse, aliasAnalysisFromSchema()), - Operator( - "aten::extend.t(t[](a!) self, t[] other) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::extend.t(t[](a!) self, t[] other) -> ()"), listExtend, aliasAnalysisFromSchema()), - Operator( - "aten::copy.t(t[](a) self) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::copy.t(t[](a) self) -> t[]"), listCopy, aliasAnalysisFromSchema()), - Operator( - "aten::_set_item.t(t [](a!) l, int idx, t(b -> *) el) -> t[](a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_set_item.t(t [](a!) l, int idx, t(b -> *) el) -> t[](a!)"), listSetItem, aliasAnalysisFromSchema()), - Operator( - "aten::clear.t(t[](a!) self) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::clear.t(t[](a!) self) -> ()"), listClear, aliasAnalysisFromSchema()), - Operator( - "aten::Delete.t(t[](a!) self, int idx) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::Delete.t(t[](a!) self, int idx) -> ()"), listDelete, aliasAnalysisFromSchema()), - Operator( - "aten::insert.t(t[](a!) self, int idx, t(b -> *) el) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::insert.t(t[](a!) self, int idx, t(b -> *) el) -> ()"), listInsert, aliasAnalysisFromSchema()), - Operator( - "aten::pop.t(t[](a!) self, int idx=-1) -> t(*)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::pop.t(t[](a!) self, int idx=-1) -> t(*)"), listPop, aliasAnalysisFromSchema()), - Operator( - "aten::add.t(t[] a, t[] b) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::add.t(t[] a, t[] b) -> t[]"), listAdd, aliasAnalysisFromSchema()), - Operator( - "aten::add_.t(t[](a!) self, t[] b) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::add_.t(t[](a!) self, t[] b) -> t[]"), listInplaceAdd, aliasAnalysisFromSchema()), - Operator( - "aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> t[]"), listSlice, aliasAnalysisFromSchema()), - Operator( - "aten::list.t(t[] l) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::list.t(t[] l) -> t[]"), listList, aliasAnalysisFromSchema()), - Operator( - "aten::mul.left_t(t[] l, int n) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::mul.left_t(t[] l, int n) -> t[]"), listMulIntLeft, aliasAnalysisFromSchema()), - Operator( - "aten::mul.right_(int n, t[] l) -> t[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::mul.right_(int n, t[] l) -> t[]"), listMulIntRight, aliasAnalysisFromSchema()), - Operator( - "aten::mul_.t(t[](a!) l, int n) -> t[](a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::mul_.t(t[](a!) l, int n) -> t[](a!)"), listMulIntLeftInPlace, aliasAnalysisFromSchema()), - Operator("aten::len.t(t[] a) -> int", listLen, aliasAnalysisFromSchema()), - Operator( - "aten::eq.int_list(int[] a, int[] b) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::len.t(t[] a) -> int"), + listLen, + aliasAnalysisFromSchema()), + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::eq.int_list(int[] a, int[] b) -> bool"), listEq, aliasAnalysisFromSchema()), - Operator( - "prim::Uninitialized() -> Any", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::Uninitialized() -> Any"), [](Stack* stack) { push(stack, IValue::uninitialized()); }, aliasAnalysisSpecialCase()), - Operator( - "prim::Print(...) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::Print(...) -> ()"), [](Stack* stack) { auto num_inputs = pop(stack).toInt(); std::stringstream ss; @@ -555,32 +570,35 @@ RegisterOperators reg( handler(ss.str()); }, aliasAnalysisSpecialCase()), - Operator( - "aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::eq.enum(AnyEnumType a, AnyEnumType b) -> bool"), [](Stack* stack) { IValue x = pop(stack); IValue y = pop(stack); push(stack, x == y); }, aliasAnalysisFromSchema()), - Operator( - "aten::ne.enum(AnyEnumType a, AnyEnumType b) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::ne.enum(AnyEnumType a, AnyEnumType b) -> bool"), [](Stack* stack) { IValue x = pop(stack); IValue y = pop(stack); push(stack, x != y); }, aliasAnalysisFromSchema()), - Operator( - "aten::dequantize.tensor(Tensor qtensor) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::dequantize.tensor(Tensor qtensor) -> Tensor"), [](Stack* stack) { at::Tensor qtensor; pop(stack, qtensor); push(stack, at::dequantize(qtensor)); }, aliasAnalysisFromSchema()), - Operator( - "aten::dequantize.any(Any tensors) -> Any", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::dequantize.any(Any tensors) -> Any"), [](Stack* stack) { dequantize(*stack); }, aliasAnalysisFromSchema()), DEFINE_STRING_OP(aten::add, a + b, str), @@ -654,8 +672,8 @@ RegisterOperators reg( static_cast(pow(a, b)), static_cast(pow(a, b)), float), - Operator( - "aten::pow.int_to_int(int a, int b) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::pow.int_to_int(int a, int b) -> int"), [](Stack* stack) { int64_t a, b; pop(stack, a, b); @@ -666,8 +684,8 @@ RegisterOperators reg( // the python builtin 'min' and 'torch.min' DEFINE_BINARY_OP(prim::min, a < b ? a : b), DEFINE_BINARY_OP(prim::max, a > b ? a : b), - Operator( - "prim::type(Device self) -> str", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::type(Device self) -> str"), [](Stack* stack) { auto d = pop(stack); push( @@ -676,8 +694,8 @@ RegisterOperators reg( }, aliasAnalysisFromSchema()), // tensor length op (size of 1st dimension) - Operator( - "aten::len.Tensor(Tensor t) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::len.Tensor(Tensor t) -> int"), [](Stack* stack) { at::Tensor t = pop(stack).toTensor(); if (t.dim() == 0) { @@ -686,8 +704,8 @@ RegisterOperators reg( push(stack, t.sizes()[0]); }, aliasAnalysisFromSchema()), - Operator( - "aten::ord(str string) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::ord(str string) -> int"), [](Stack& stack) { auto string = pop(stack).toStringRef(); TORCH_CHECK( @@ -699,8 +717,8 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), - Operator( - "aten::lower(str self) -> str", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::lower(str self) -> str"), [](Stack& stack) { auto string = pop(stack).toStringRef(); std::stringstream ss; @@ -711,20 +729,22 @@ RegisterOperators reg( return 0; }, aliasAnalysisFromSchema()), - Operator( - "aten::__contains__.str_list(str[] l, str item) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__contains__.str_list(str[] l, str item) -> bool"), listContains, aliasAnalysisFromSchema()), - Operator( - "aten::len.str(str s) -> int", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::len.str(str s) -> int"), [](Stack& stack) { auto string = pop(stack).toStringRef(); push(stack, static_cast(string.size())); return 0; }, aliasAnalysisFromSchema()), - Operator( - "aten::__getitem__.str(str s, int index) -> str", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::__getitem__.str(str s, int index) -> str"), [](Stack& stack) { auto index = pop(stack).toInt(); auto string = pop(stack).toStringRef(); @@ -735,9 +755,10 @@ RegisterOperators reg( }, aliasAnalysisFromSchema()), #define CREATE_COPY_OP(other_type, c_type) \ - Operator( \ - "aten::copy_." #other_type "(Tensor(a!) self, " #other_type \ - " other) -> Tensor(a!)", \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::copy_." #other_type \ + "(Tensor(a!) self, " #other_type \ + " other) -> Tensor(a!)"), \ [](Stack* stack) { \ at::Tensor t; \ c_type other; \ @@ -751,8 +772,9 @@ RegisterOperators reg( CREATE_COPY_OP(int, int64_t), CREATE_COPY_OP(float, double), #undef CREATE_COPY_OP - Operator( - "aten::backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::backward(Tensor self, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()"), [](Stack* stack) { bool create_graph = pop(stack).toBool(); auto retain_graph = pop(stack).toOptional(); @@ -770,8 +792,9 @@ RegisterOperators reg( // and nullability scrubbed from TensorList arg types // TOOD find out why this exists and how to do it without the hack // - Operator( - "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"), [](Stack* stack) { auto indices = pop(stack).toTensorVector(); auto self = pop(stack).toTensor(); @@ -779,8 +802,9 @@ RegisterOperators reg( push(stack, std::move(result)); }, aliasAnalysisFromSchema()), - Operator( - "aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_index_put_impl_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)"), [](Stack* stack) { auto unsafe = pop(stack).toBool(); auto accumulate = pop(stack).toBool(); @@ -792,8 +816,9 @@ RegisterOperators reg( push(stack, std::move(result)); }, aliasAnalysisFromSchema()), - Operator( - "aten::index_put_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::index_put_.hacked_twin(Tensor(a!) self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)"), [](Stack* stack) { auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); @@ -803,8 +828,9 @@ RegisterOperators reg( push(stack, std::move(result)); }, aliasAnalysisFromSchema()), - Operator( - "aten::index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::index_put.hacked_twin(Tensor self, Tensor[] indices, Tensor values, bool accumulate=False) -> Tensor"), [](Stack* stack) { auto accumulate = pop(stack).toBool(); auto values = pop(stack).toTensor(); @@ -815,8 +841,9 @@ RegisterOperators reg( }, aliasAnalysisFromSchema()), // reference function parse_to_conversion in python_arg_parsing.h - Operator( - "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack* stack) { bool non_blocking; bool copy; @@ -831,8 +858,9 @@ RegisterOperators reg( to_dispatch(self, device, scalarType, non_blocking, copy)); }, aliasAnalysisFromSchema()), - Operator( - "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::to.prim_dtype(Tensor(a) self, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"), [](Stack* stack) { bool non_blocking; bool copy; @@ -846,16 +874,16 @@ RegisterOperators reg( to_dispatch(self, device, scalarType, non_blocking, copy)); }, aliasAnalysisFromSchema()), - Operator( - "prim::is_cuda(Tensor a) -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::is_cuda(Tensor a) -> bool"), [](Stack* stack) { at::Tensor a; pop(stack, a); push(stack, a.is_cuda()); }, aliasAnalysisFromSchema()), - Operator( - "prim::data(Tensor(a) a) -> Tensor(a)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("prim::data(Tensor(a) a) -> Tensor(a)"), [](Stack* stack) { at::Tensor a; pop(stack, a); @@ -863,24 +891,27 @@ RegisterOperators reg( }, aliasAnalysisFromSchema()), // these ops are not defined for Tensor -#define CREATE_COMPARATOR_LIST_OPS_SPECIALIZED(decl_type, value_type) \ - Operator( \ - "prim::min." decl_type "_list(" decl_type "[] l, " decl_type \ - "[] r) -> " decl_type "[]", \ - minList, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "prim::max." decl_type "_list(" decl_type "[] l, " decl_type \ - "[] r) -> " decl_type "[]", \ - maxList, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "prim::min.self_" decl_type "(" decl_type "[] self) -> " decl_type, \ - listMin, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "prim::max.self_" decl_type "(" decl_type "[] self) -> " decl_type, \ - listMax, \ +#define CREATE_COMPARATOR_LIST_OPS_SPECIALIZED(decl_type, value_type) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("prim::min." decl_type "_list(" decl_type \ + "[] l, " decl_type "[] r) -> " decl_type "[]"), \ + minList, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("prim::max." decl_type "_list(" decl_type \ + "[] l, " decl_type "[] r) -> " decl_type \ + "[]"), \ + maxList, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("prim::min.self_" decl_type "(" decl_type \ + "[] self) -> " decl_type), \ + listMin, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("prim::max.self_" decl_type "(" decl_type \ + "[] self) -> " decl_type), \ + listMax, \ aliasAnalysisFromSchema()), CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("int", int64_t) CREATE_COMPARATOR_LIST_OPS_SPECIALIZED("float", double) @@ -1055,94 +1086,106 @@ void dictConstructFromList(Stack* stack) { push(stack, dict); } -#define CREATE_DICT_OPS(key_type) \ - Operator( \ - "aten::len.Dict_" key_type "(Dict(" key_type ", t) self) -> int", \ - dictLen, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::keys." key_type "(Dict(" key_type ", t) self) -> " key_type \ - "[](*)", \ - dictKeys, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::values." key_type "(Dict(" key_type ", t) self) -> t[](*)", \ - dictValues, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::__getitem__.Dict_" key_type "(Dict(" key_type \ - ", t) self, " key_type " key) -> t(*)", \ - dictIndex, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::get." key_type "(Dict(" key_type ", t) self, " key_type \ - " key) -> t(*)?", \ - dictGet, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::get.default_" key_type "(Dict(" key_type \ - ", t) self, " key_type " key, t default_value) -> t(*)", \ - dictGet, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::setdefault." key_type "(Dict(" key_type \ - ", t)(a!) self, " key_type \ - "(b -> *) key, t(c -> *) default_value) -> t(*)", \ - dictSetDefault, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::Delete.Dict_" key_type "(Dict(" key_type \ - ", t)(a!) self, " key_type " key) -> ()", \ - dictDelete, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::pop.Dict_" key_type "(Dict(" key_type \ - ", t)(a!) self, " key_type " key) -> t(*)", \ - dictPop, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::pop.Dict_default_" key_type "(Dict(" key_type \ - ", t)(a!) self, " key_type " key, t default_value) -> t(*)", \ - dictPop, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::popitem." key_type "(Dict(" key_type \ - ", t)(a!) self) -> ((" key_type ", t))", \ - dictPopItem, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::clear." key_type "(Dict(" key_type ", t)(a!) self) -> ()", \ - dictClear, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::update." key_type "(Dict(" key_type \ - ", t)(a!) self, Dict(" key_type ", t)(a!) to_add) -> ()", \ - dictUpdate, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::items." key_type "(Dict(" key_type \ - ", t) self) -> ((" key_type ", t)[])", \ - dictItems, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::copy.Dict_" key_type "(Dict(" key_type \ - ", t)(a) self) -> Dict(" key_type ", t)", \ - dictCopy, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::__contains__." key_type "(Dict(" key_type \ - ", t) dict, " key_type " key) -> bool", \ - dictContains, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::_set_item." key_type "(Dict(" key_type \ - ", t)(a!) l, " key_type "(b -> *) idx, t(c -> *) v) -> ()", \ - dictSetItem, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::dict." key_type "((" key_type \ - ", tVal)[] inputs) -> Dict(" key_type ", tVal)", \ - dictConstructFromList, \ +#define CREATE_DICT_OPS(key_type) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::len.Dict_" key_type "(Dict(" key_type \ + ", t) self) -> int"), \ + dictLen, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::keys." key_type "(Dict(" key_type \ + ", t) self) -> " key_type "[](*)"), \ + dictKeys, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::values." key_type "(Dict(" key_type \ + ", t) self) -> t[](*)"), \ + dictValues, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::__getitem__.Dict_" key_type \ + "(Dict(" key_type ", t) self, " key_type \ + " key) -> t(*)"), \ + dictIndex, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::get." key_type "(Dict(" key_type \ + ", t) self, " key_type " key) -> t(*)?"), \ + dictGet, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::get.default_" key_type \ + "(Dict(" key_type ", t) self, " key_type \ + " key, t default_value) -> t(*)"), \ + dictGet, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + "aten::setdefault." key_type "(Dict(" key_type \ + ", t)(a!) self, " key_type \ + "(b -> *) key, t(c -> *) default_value) -> t(*)"), \ + dictSetDefault, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::Delete.Dict_" key_type \ + "(Dict(" key_type ", t)(a!) self, " key_type \ + " key) -> ()"), \ + dictDelete, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::pop.Dict_" key_type "(Dict(" key_type \ + ", t)(a!) self, " key_type " key) -> t(*)"), \ + dictPop, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::pop.Dict_default_" key_type \ + "(Dict(" key_type ", t)(a!) self, " key_type \ + " key, t default_value) -> t(*)"), \ + dictPop, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::popitem." key_type "(Dict(" key_type \ + ", t)(a!) self) -> ((" key_type ", t))"), \ + dictPopItem, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::clear." key_type "(Dict(" key_type \ + ", t)(a!) self) -> ()"), \ + dictClear, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::update." key_type "(Dict(" key_type \ + ", t)(a!) self, Dict(" key_type \ + ", t)(a!) to_add) -> ()"), \ + dictUpdate, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::items." key_type "(Dict(" key_type \ + ", t) self) -> ((" key_type ", t)[])"), \ + dictItems, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::copy.Dict_" key_type "(Dict(" key_type \ + ", t)(a) self) -> Dict(" key_type ", t)"), \ + dictCopy, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::__contains__." key_type \ + "(Dict(" key_type ", t) dict, " key_type \ + " key) -> bool"), \ + dictContains, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::_set_item." key_type "(Dict(" key_type \ + ", t)(a!) l, " key_type \ + "(b -> *) idx, t(c -> *) v) -> ()"), \ + dictSetItem, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA("aten::dict." key_type "((" key_type \ + ", tVal)[] inputs) -> Dict(" key_type \ + ", tVal)"), \ + dictConstructFromList, \ aliasAnalysisFromSchema()) RegisterOperators reg_dict_ops({ diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 249d39995bf3191..8922c521af347e7 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -35,6 +35,16 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + Operator( + prim::profile_optional, + [](const Node* node) -> Operation { + auto callback = node->cast()->getCallback(); + return [](Stack* stack) { + AT_ERROR( + "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT + }; + }, + aliasAnalysisSpecialCase()), Operator( prim::FusionGroup, [](const Node* node) -> Operation { @@ -53,6 +63,15 @@ RegisterOperators reg( }; }, aliasAnalysisSpecialCase()), + Operator( + prim::FallbackGraph, + [](const Node* node) -> Operation { + return [](Stack* stack) { + AT_ERROR( + "Must be converted to prim::FunctionCall by replaceFallbackGraphWithFallbackFunction"); // NOLINT + }; + }, + aliasAnalysisSpecialCase()), Operator( "prim::Guard(Tensor(a) t) -> Tensor(a)", [](Stack* stack) { AT_ERROR("Should be replaced by prim::BailOut"); }, diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 1a82cb7694874a1..0de2d359c18373a 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -26,6 +27,10 @@ c10::AliasAnalysisKind aliasAnalysisFromSchema() { return c10::AliasAnalysisKind::FROM_SCHEMA; } +c10::AliasAnalysisKind aliasAnalysisConservative() { + return c10::AliasAnalysisKind::CONSERVATIVE; +} + void checkListInputType(const c10::TypePtr& elem_type, bool empty_list) { if (!elem_type->isSubtypeOf(NumberType::get()) && elem_type != BoolType::get()) { @@ -235,8 +240,9 @@ void createTensorFromList(Stack* stack) { } RegisterOperators reg({ - Operator( - "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::split(Tensor self, int[] split_sizes, int dim=0) -> Tensor[]"), [](Stack* stack) { RECORD_FUNCTION("split_with_sizes", last(stack, 3)); @@ -249,35 +255,37 @@ RegisterOperators reg({ }, aliasAnalysisFromSchema()), -#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op) \ - Operator( \ - "aten::tensor." #operator_type "(" #operator_type \ - " t, *, ScalarType? dtype=None, Device? device=None" \ - ", bool requires_grad=False) -> Tensor", \ - [](Stack* stack) { \ - c_type scalar_val; \ - IValue dtype; \ - IValue device; \ - bool requires_grad; \ - pop(stack, scalar_val, dtype, device, requires_grad); \ - auto tensor = tensor_creation_op; \ - tensor = castTensorTo(tensor, dtype, device); \ - tensor.set_requires_grad(requires_grad); \ - push(stack, std::move(tensor)); \ - }, \ - aliasAnalysisFromSchema()), \ - Operator( \ - "aten::as_tensor." #operator_type "(" #operator_type \ - " t, *, ScalarType? dtype=None, Device? device=None) -> Tensor", \ - [](Stack* stack) { \ - c_type scalar_val; \ - IValue dtype; \ - IValue device; \ - pop(stack, scalar_val, dtype, device); \ - auto tensor = tensor_creation_op; \ - tensor = castTensorTo(tensor, dtype, device); \ - push(stack, std::move(tensor)); \ - }, \ +#define DEFINE_TORCH_TENSOR_OP(operator_type, c_type, tensor_creation_op) \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + "aten::tensor." #operator_type "(" #operator_type \ + " t, *, ScalarType? dtype=None, Device? device=None" \ + ", bool requires_grad=False) -> Tensor"), \ + [](Stack* stack) { \ + c_type scalar_val; \ + IValue dtype; \ + IValue device; \ + bool requires_grad; \ + pop(stack, scalar_val, dtype, device, requires_grad); \ + auto tensor = tensor_creation_op; \ + tensor = castTensorTo(tensor, dtype, device); \ + tensor.set_requires_grad(requires_grad); \ + push(stack, std::move(tensor)); \ + }, \ + aliasAnalysisFromSchema()), \ + OperatorGenerator( \ + TORCH_SELECTIVE_SCHEMA( \ + "aten::as_tensor." #operator_type "(" #operator_type \ + " t, *, ScalarType? dtype=None, Device? device=None) -> Tensor"), \ + [](Stack* stack) { \ + c_type scalar_val; \ + IValue dtype; \ + IValue device; \ + pop(stack, scalar_val, dtype, device); \ + auto tensor = tensor_creation_op; \ + tensor = castTensorTo(tensor, dtype, device); \ + push(stack, std::move(tensor)); \ + }, \ aliasAnalysisFromSchema()), DEFINE_TORCH_TENSOR_OP( @@ -294,16 +302,17 @@ RegisterOperators reg({ // reference python implementation: internal_new_from_data in // tensor_new.cpp - Operator( - "aten::_infer_size(int[] a, int[] b) -> int[]", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::_infer_size(int[] a, int[] b) -> int[]"), [](Stack* stack) { auto a = pop(stack); auto b = pop(stack); push(stack, at::infer_size(a.toIntVector(), b.toIntVector())); }, aliasAnalysisFromSchema()), - Operator( - "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor"), [](Stack* stack) { at::Tensor weight; at::Tensor input; @@ -319,12 +328,14 @@ RegisterOperators reg({ push(stack, std::move(result)); }, aliasAnalysisFromSchema()), - Operator( - "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::tensor(t[] data, *, ScalarType? dtype=None, Device? device=None, bool requires_grad=False) -> Tensor"), createTensorFromList, aliasAnalysisFromSchema()), - Operator( - "aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(a|b)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::as_tensor(Tensor(a) data, *, ScalarType? dtype=None, Device? device=None) -> Tensor(a|b)"), [](Stack* stack) { auto device = pop(stack).toOptional(); auto dtype = pop(stack).toOptional(); @@ -340,25 +351,28 @@ RegisterOperators reg({ push(stack, std::move(data)); }, aliasAnalysisFromSchema()), - Operator( - "aten::as_tensor.list(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::as_tensor.list(t[] data, *, ScalarType? dtype=None, Device? device=None) -> Tensor"), createTensorFromList, aliasAnalysisFromSchema()), - Operator( - "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, " - "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_pack_sequence(Tensor output, Tensor batch_sizes, Tensor? sorted_indices, " + "Tensor? unsorted_indices) -> (Tensor, Tensor, Tensor?, Tensor?)"), [](Stack* stack) {}, aliasAnalysisFromSchema()), - Operator( - "aten::_get_tracing_state() -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::_get_tracing_state() -> bool"), [](Stack* stack) { push(stack, false); }, aliasAnalysisFromSchema()), - Operator( - "aten::is_scripting() -> bool", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA("aten::is_scripting() -> bool"), [](Stack* stack) { push(stack, true); }, aliasAnalysisFromSchema()), - Operator( - "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"), [](Stack* stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -370,8 +384,9 @@ RegisterOperators reg({ push(stack, tensor.uniform_(a, b)); }, aliasAnalysisFromSchema()), - Operator( - "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std) -> Tensor(a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std) -> Tensor(a!)"), [](Stack* stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -383,8 +398,9 @@ RegisterOperators reg({ push(stack, tensor.normal_(mean, std)); }, aliasAnalysisFromSchema()), - Operator( - "aten::_no_grad_fill_(Tensor(a!) tensor, float val) -> Tensor(a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_no_grad_fill_(Tensor(a!) tensor, float val) -> Tensor(a!)"), [](Stack* stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -395,8 +411,9 @@ RegisterOperators reg({ push(stack, at::fill_(tensor, val)); }, aliasAnalysisFromSchema()), - Operator( - "aten::_no_grad_zero_(Tensor(a!) tensor) -> Tensor(a!)", + OperatorGenerator( + TORCH_SELECTIVE_SCHEMA( + "aten::_no_grad_zero_(Tensor(a!) tensor) -> Tensor(a!)"), [](Stack* stack) { // TODO: remove when script supports setting grad mode torch::NoGradGuard no_grad; @@ -406,7 +423,17 @@ RegisterOperators reg({ push(stack, at::zero_(tensor)); }, aliasAnalysisFromSchema()), - + Operator( + "aten::is_grad_enabled() -> bool", + [](Stack* stack) { push(stack, torch::GradMode::is_enabled()); }, + aliasAnalysisConservative()), + Operator( + "aten::set_grad_enabled(bool val) -> ()", + [](Stack* stack) { + torch::GradMode::set_enabled(pop(stack).toBool()); + push(stack, IValue()); + }, + aliasAnalysisConservative()), }); } // namespace } // namespace jit diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 5067065c3ce097c..337b9bdef440971 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -4,6 +4,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -95,23 +96,122 @@ StaticRuntime::StaticRuntime(const torch::jit::Module& m) %r = static::mul(%y, %s) return (%r))IR"); sr.runOnGraph(graph_); - code_ = std::make_unique(graph_, ""); - interp_ = std::make_unique(*code_); + + // remove unused input 0 from graph + if (graph_->inputs().at(0)->type()->is_module()) { + if (!graph_->inputs().at(0)->hasUses()) { + graph_->eraseInput(0); + } + } + + // fill constant_table_ and operator_table_ + for (Node* node : graph_->nodes()) { + switch (node->kind()) { + case prim::Constant: + CHECK(node->output()->type()->kind() != FunctionType::Kind); + constant_table_[node->output()] = toIValue(node->output()).value(); + break; + case prim::ListConstruct: + nodes_.emplace_back(node, nullptr); + break; + case prim::TupleConstruct: + nodes_.emplace_back(node, nullptr); + break; + default: { + const Operator& op = node->getOperator(); + CHECK(op.hasOperation()); + nodes_.emplace_back(node, op.getOperation(node)); + } + } + } +} + +void StaticRuntime::getInputIValues( + Node* node, + const ConstantMap& ws, + std::vector& stack) const { + const size_t size = node->inputs().size(); + stack.reserve(size); + for (size_t i = 0; i < size; i++) { + Value* v = node->inputs()[i]; + auto f = constant_table_.find(v); + if (f == constant_table_.end()) { + auto f_ws = ws.find(v); + TORCH_CHECK( + f_ws != ws.end(), + "Workspace does not contain Value ", + v->debugName()); + stack.emplace_back(f_ws->second); + } else { + stack.emplace_back(f->second); + } + } +} + +void StaticRuntime::runNodes(ConstantMap& workspace) const { + std::vector stack; + for (const auto& p : nodes_) { + Node* node = p.first; + const Operation& op = p.second; + getInputIValues(node, workspace, stack); + VLOG(1) << node->kind().toDisplayString(); + + switch (node->kind()) { + case prim::ListConstruct: { + listConstruct( + stack, + node->output()->type()->expect(), + node->inputs().size()); + } break; + case prim::TupleConstruct: { + bool named = + node->output()->type()->expect()->name().has_value(); + if (named) { + namedTupleConstruct( + stack, + node->output()->type()->expect(), + node->inputs().size()); + } else { + tupleConstruct(stack, node->inputs().size()); + } + } break; + default: { + DCHECK(op); + op(&stack); + break; + } + } + + DCHECK_EQ(stack.size(), node->outputs().size()); + for (auto i = 0; i < node->outputs().size(); i++) { + workspace[node->outputs()[i]] = stack[i]; + } + stack.clear(); + } } std::vector StaticRuntime::run( const std::vector& inps) const { - std::vector stack; - if (graph_->inputs().at(0)->type()->is_module()) { - stack.emplace_back(module_._ivalue()); + // Container for inputs, outputs, and activations (excluding parameters) + ConstantMap workspace_; + + int start = 0; + if (graph_->inputs().size() != inps.size()) { + start = 1; + CHECK_EQ(graph_->inputs().size(), inps.size() + 1); + CHECK((graph_->inputs().at(0)->type()->is_module())); + workspace_.emplace(graph_->inputs()[0], module_._ivalue()); } - for (const auto& inp : inps) { - stack.emplace_back(inp); + + for (size_t i = 0; i < inps.size(); i++) { + workspace_.emplace(graph_->inputs()[i + start], inps[i]); } - interp_->run(stack); + runNodes(workspace_); + std::vector out; - for (const auto& v : stack) { + for (Value* output : graph_->outputs()) { + const IValue& v = workspace_[output]; if (v.isTuple()) { auto t = v.toTuple(); for (const auto& el : t->elements()) { diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index b3c8b0eee703218..c9b5dfd6a62663e 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -7,6 +7,10 @@ #include #include +#ifdef FBCODE_CAFFE2 +#include +#endif + namespace torch { namespace jit { @@ -19,13 +23,28 @@ class TORCH_API StaticRuntime { std::vector run(const std::vector& inps) const; +#ifdef FBCODE_CAFFE2 + using ConstantMap = folly::F14FastMap; +#else + using ConstantMap = std::unordered_map; +#endif + private: torch::jit::Module module_; std::shared_ptr graph_; - // Jit interpreter state - std::unique_ptr code_; - std::unique_ptr interp_; + // Static runtime states + // Constant table (including weights) + ConstantMap constant_table_; + // The nodes we need to run + std::vector> nodes_; + + void getInputIValues( + Node* node, + const ConstantMap& ws, + std::vector& stack) const; + + void runNodes(ConstantMap& ws_) const; }; } // namespace jit diff --git a/torch/csrc/jit/tensorexpr/analysis.h b/torch/csrc/jit/tensorexpr/analysis.h index 10a821e4908a03e..bf983ed67ce6b0b 100644 --- a/torch/csrc/jit/tensorexpr/analysis.h +++ b/torch/csrc/jit/tensorexpr/analysis.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace torch { namespace jit { @@ -67,6 +68,31 @@ class VarFinder : public IRVisitor { std::unordered_set vars_; }; +// A class that analyzes the given program relevant for Block backend +// It creates a map of multi dim buffers and their flat verions +class CreateBufferMap : public IRVisitor { + public: + const std::unordered_map& getBufferMap() const { + return map_input_to_tensor_bufs_; + } + + private: + void visit(const Store* v) override { + auto load_node = dynamic_cast(v->value()); + auto call_node = dynamic_cast(v->value()); + if (load_node || call_node) { + TORCH_INTERNAL_ASSERT(!(load_node && call_node)); + auto t_buf = load_node ? load_node->buf() : call_node->tensor()->buf(); + if (load_node) { + map_input_to_tensor_bufs_.emplace(t_buf->name_hint(), v->buf()); + } else { + map_input_to_tensor_bufs_.emplace(v->buf()->name_hint(), t_buf); + } + } + v->value()->accept(this); + } + std::unordered_map map_input_to_tensor_bufs_; +}; } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/block_codegen.cpp b/torch/csrc/jit/tensorexpr/block_codegen.cpp new file mode 100644 index 000000000000000..3116a2c8d6a5861 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/block_codegen.cpp @@ -0,0 +1,374 @@ +#include + +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +DEFINE_TRIGGER(block_codegen_created); +std::string blockDtypeCppString(const Dtype& dtype) { + switch (dtype.scalar_type()) { + case ScalarType::Bool: + return "1"; + case ScalarType::Half: + return "2"; + case ScalarType::Char: + return "1"; + case ScalarType::Byte: + return "1"; + case ScalarType::Short: + return "4"; + case ScalarType::Long: + return "8"; + case ScalarType::Float: + return "2"; // Return Half for now + default: + return dtype.ToCppString(); + } +} + +bool BlockAnalysis::areBufsInMap( + const std::unordered_set& bufs) const { + for (auto const& arg : bufs) { + auto got = map_input_to_tensor_bufs_.find(arg->name_hint()); + if (got == map_input_to_tensor_bufs_.end()) { + return false; + } + } + return true; +} + +const Buf* BlockAnalysis::getMultiDimBuf(const Buf* buf) const { + auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint()); + if (input_ != map_input_to_tensor_bufs_.end()) { + return input_->second; + } else { + throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map"); + } +} + +std::string BlockAnalysis::getInputName(const Buf* buf) const { + auto input_ = map_input_to_tensor_bufs_.find(buf->name_hint()); + if (input_ != map_input_to_tensor_bufs_.end()) { + return input_->second->name_hint(); + } else { + throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map"); + } +} + +void BlockAnalysis::visit(const Store* v) { + store_targets_.insert(v->buf()); + v->value()->accept(this); +} + +void BlockAnalysis::visit(const Load* v) { + loads_.insert(v->buf()); +} + +void BlockAnalysis::visit(const For* v) { + const LoopOptions& loop_options = v->loop_options(); + if (loop_options.is_gpu_block_index()) { + map_input_to_tensor_bufs_ = loop_options.get_buffer_mapping(); + v->body()->accept(this); + } else if (loop_options.is_gpu_thread_index()) { + auto block_size = v->stop(); + block_size_ = dynamic_cast(block_size)->value(); + v->body()->accept(this); + } else { + IRVisitor::visit(v); + } +} + +// For both Add, Mul we only print out the opening +// paranthesis. This behavior is to handle blocks add Op +// where c=a+b becomes add(a, b, c). The closing paran is +// added in the store statement. +// TODO: When handling fused ops d = a + b + c, the correct +// way would be to mutate the expression to Block version and print. + +void BlockPrinter::visit(const Add* v) { + emitIndent(); + os() << "add("; + v->lhs()->accept(this); + v->rhs()->accept(this); +} + +void BlockPrinter::visit(const Mul* v) { + emitIndent(); + os() << "mul("; + v->lhs()->accept(this); + v->rhs()->accept(this); +} + +void BlockPrinter::visit(const For* v) { + const LoopOptions& loop_options = v->loop_options(); + + auto buf_reads = block_analysis_->loads(); + auto buf_writes = block_analysis_->stores(); + std::unordered_set bufs(buf_reads.begin(), buf_reads.end()); + bufs.insert(buf_writes.begin(), buf_writes.end()); + + if (loop_options.is_gpu_block_index()) { + emitIndent(); + PrintTensorInfo(bufs); + PrintDistribution(bufs); + PrintBufferInfo(buf_reads); + PrintArguments(bufs); + + emitIndent(); + os() << "compute {" << std::endl; + + PrintReshapeInfo(bufs); + + emitIndent(); + PrintLoop(bufs, true); + v->body()->accept(this); + + os() << std::endl; + emitIndent(); + PrintReshapeInfo(buf_writes, true); // print reverse reshape + os() << "}"; + os() << std::endl; + } else if (loop_options.is_gpu_thread_index()) { + PrintDMAs(buf_reads); + PrintLoop(buf_reads, false); + v->body()->accept(this); + os() << std::endl; + PrintAdjustBuffers(buf_reads); + + } else { + IRPrinter::visit(v); + } +} + +void BlockPrinter::PrintTensorInfo(const std::unordered_set& bufs) { + os() << "tensors {"; + for (const auto& buf : bufs) { + os() << std::endl; + emitIndent(); + emitIndent(); + auto num_dims = block_analysis_->getMultiDimBuf(buf)->dims().size(); + os() << block_analysis_->getInputName(buf) << " = "; + os() << "{"; + for (unsigned long d = 0; d < num_dims; d++) { + os() << "{" << dim_names[d] << "};"; + } + os() << " elem : " << blockDtypeCppString(buf->dtype()); + os() << "}"; + } + + for (const auto& buf : bufs) { + os() << std::endl; + emitIndent(); + emitIndent(); + auto num_dims = block_analysis_->getMultiDimBuf(buf)->dims().size(); + os() << block_analysis_->getFlatInputName(buf) << " = "; + os() << "{"; + os() << "{" << flat_dim_names[num_dims - 1] << "};"; + os() << " elem : " << blockDtypeCppString(buf->dtype()); + os() << "}" + << " // flattened tensor"; + } + os() << std::endl; + emitIndent(); + os() << "}" << std::endl << std::endl; +} + +void BlockPrinter::PrintArguments(const std::unordered_set& bufs) { + for (const auto& buf : bufs) { + auto multidimbuf = block_analysis_->getMultiDimBuf(buf); + auto num_dims = multidimbuf->dims().size(); + + // The dims for the multi-dim tensors + for (unsigned long d = 0; d < num_dims; d++) { + auto dim_val = dynamic_cast(multidimbuf->dim(d)); + this->dim_values_map.emplace(this->dim_names[d], dim_val->value()); + } + + // The dimensions for the flattened tensors + auto val = dynamic_cast(buf->dim(0)); + if (block_analysis_->is_buf_store_target(buf)) { + this->dim_values_map.emplace( + this->flat_dim_names[num_dims - 1], val->value()); + } + } + + emitIndent(); + os() << "arguments {" << std::endl; + + for (auto const& arg : this->dim_values_map) { + emitIndent(); + os() << "var " << arg.first << " = " << arg.second << std::endl; + } + + emitIndent(); + emitIndent(); + auto blck_sz = block_analysis_->block_size(); + os() << "var bs_N = " << blck_sz << std::endl; + emitIndent(); + emitIndent(); + os() << "var bs_DPE = " << blck_sz << std::endl; + emitIndent(); + os() << "}" << std::endl << std::endl; +} + +void BlockPrinter::PrintBufferInfo(const std::unordered_set& bufs) { + emitIndent(); + os() << "buffers {"; + for (const auto& read : bufs) { + os() << std::endl; + emitIndent(); + emitIndent(); + os() << block_analysis_->getFlatInputName(read) << " = "; + os() << "{{" + << "bs_DPE" + << "}}"; + } + os() << std::endl; + emitIndent(); + os() << "}" << std::endl << std::endl; +} + +void BlockPrinter::PrintDistribution( + const std::unordered_set& bufs) { + emitIndent(); + os() << "distribution {" << std::endl; + for (const auto& buf : bufs) { + emitIndent(); + emitIndent(); + auto buf_name = buf->name_hint(); + os() << block_analysis_->getFlatInputName(buf) << " = "; + os() << "{(0, 1, )}" << std::endl; + } + os() << " }" << std::endl << std::endl; +} + +void BlockPrinter::PrintLoop( + const std::unordered_set& bufs, + bool block_idx) { + emitIndent(); + os() << "loop ("; + auto trip = 0; + for (const auto& buf : bufs) { + if (trip > 0) { + os() << ","; + } + os() << "{dim : "; + os() << block_analysis_->getFlatInputName(buf) << ".dim.0, "; + os() << (block_idx ? "block: bs_N}" : "block: bs_DPE}"); + ++trip; + } + os() << ")"; +} + +void BlockPrinter::PrintReshapeInfo( + const std::unordered_set& bufs, + bool reverse) { + for (const auto& buf : bufs) { + emitIndent(); + os() << "reshape(" + << (reverse ? block_analysis_->getFlatInputName(buf) + : block_analysis_->getInputName(buf)) + << ", " + << (reverse ? block_analysis_->getInputName(buf) + : block_analysis_->getFlatInputName(buf)) + << ")" << std::endl; + } +} + +void BlockPrinter::PrintDMAs(const std::unordered_set& bufs) { + for (const auto& read : bufs) { + emitIndent(); + os() << "dma_in("; + os() << block_analysis_->getFlatInputName(read); + os() << ")" << std::endl; + } +} +void BlockPrinter::PrintAdjustBuffers( + const std::unordered_set& bufs) { + for (const auto& read : bufs) { + emitIndent(); + os() << "adjust_buffer("; + os() << block_analysis_->getFlatInputName(read); + os() << ")" << std::endl; + } +} + +void BlockPrinter::visit(const Load* v) { + os() << block_analysis_->getFlatInputName(v->buf()) << ".buffer, "; +} +void BlockPrinter::visit(const Store* v) { + emitIndent(); + os() << *v->value() << block_analysis_->getFlatInputName(v->buf()) + << ".tensor)" << std::endl; +} + +void BlockPrinter::visit(const Block* v) { + os() << "{" << std::endl; + indent_++; + for (Stmt* s : v->stmts()) { + s->accept(this); + } + indent_--; + emitIndent(); + os() << "}"; +} + +std::string BlockCodeGen::GetUniqueFuncName(const std::string& func_prefix) { + // We are using a global counter here to make sure difference instances + // within BlockCodeGen have different names. + static int64_t counter = 0; + ++counter; + int64_t value = counter; + return func_prefix + "_" + c10::to_string(value); +} + +void BlockCodeGen::Initialize() { + block_analysis_ = std::make_unique(); + printer_ = std::make_unique(&oss_, block_analysis_.get()); + + Stmt* stmt_v = stmt(); + stmt_v->accept(block_analysis_.get()); + + auto buf_reads = block_analysis_->loads(); + auto buf_writes = block_analysis_->stores(); + // Ensure all Bufs in reads/writes are in the map + std::unordered_set bufs(buf_reads.begin(), buf_reads.end()); + bufs.insert(buf_writes.begin(), buf_writes.end()); + if (!block_analysis_->areBufsInMap(bufs)) { + throw std::runtime_error("BlockCodeGen: Entry not in input/Buffer map"); + }; + + std::string func_name = GetUniqueFuncName("func"); + os() << "kernel " << func_name << "("; + for (auto const& arg : buf_writes) { + os() << block_analysis_->getInputName(arg); + } + for (auto const& arg : buf_reads) { + os() << ";" << block_analysis_->getInputName(arg); + } + os() << ")"; + + stmt_v->accept(printer_.get()); + + GRAPH_DEBUG("Generated Block code: ", oss_.str(), "\n"); + + USE_TRIGGER(block_codegen_created); +} + +void BlockCodeGen::call(const std::vector& args) { + throw std::runtime_error("BlockCodeGen: Cannot call Block code "); +} + +BlockCodeGen::~BlockCodeGen() = default; +RegisterCodeGen block_codegen_reg("block_codegen"); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/block_codegen.h b/torch/csrc/jit/tensorexpr/block_codegen.h new file mode 100644 index 000000000000000..fcd88e040e176e2 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/block_codegen.h @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +// A class that analyzes the given program relevant for Block backend. +class BlockAnalysis : public IRVisitor { + public: + bool is_buf_store_target(const Buf* buf) const { + return store_targets_.count(buf) > 0; + } + + const std::unordered_set& loads() const { + return loads_; + } + + const std::unordered_set& stores() const { + return store_targets_; + } + + int block_size() const { + return block_size_; + } + + bool areBufsInMap(const std::unordered_set& bufs) const; + + const Buf* getMultiDimBuf(const Buf* buf) const; + + std::string getInputName(const Buf* buf) const; + + std::string getFlatInputName(const Buf* buf) const { + return getInputName(buf) + "_flat"; + } + + std::unordered_map getBufferMap() const { + return map_input_to_tensor_bufs_; + } + + private: + void visit(const Store* v) override; + void visit(const Load* v) override; + void visit(const For* v) override; + + std::unordered_map map_input_to_tensor_bufs_; + std::unordered_set store_targets_; + std::unordered_set loads_; + int block_size_ = 32; +}; + +// A class that overrides the underlying IRPrinter to produce Block. +class BlockPrinter : public IRPrinter { + public: + BlockPrinter(std::ostream* os, const BlockAnalysis* block_analysis) + : IRPrinter(*os), block_analysis_(block_analysis) {} + + using IRPrinter::name_manager; + using IRPrinter::visit; + + private: + const BlockAnalysis* block_analysis_; + std::unordered_map dim_values_map; + std::vector dim_names = {"N", "H", "W", "C"}; + std::vector flat_dim_names = {"N", "NH", "NHW", "NHWC"}; + void PrintTensorInfo(const std::unordered_set& bufs); + void PrintArguments(const std::unordered_set& bufs); + void PrintBufferInfo(const std::unordered_set& bufs); + void PrintDistribution(const std::unordered_set& bufs); + void PrintLoop( + const std::unordered_set& bufs, + bool block_idx = true); + void PrintReshapeInfo( + const std::unordered_set& bufs, + bool reverse = false); + void PrintDMAs(const std::unordered_set& bufs); + void PrintAdjustBuffers(const std::unordered_set& bufs); + + void visit(const For* v) override; + void visit(const Load* v) override; + void visit(const Store* v) override; + void visit(const Block* v) override; + void visit(const Add* v) override; + void visit(const Mul* v) override; +}; + +class TORCH_API BlockCodeGen : public CodeGen { + public: + template + /* implicit */ + BlockCodeGen(Stmt* stmt, Ts... ts) + : CodeGen( + stmt, + std::vector({BufferArg(ts)...}), + at::Device(at::kCPU)) { + Initialize(); + } + + BlockCodeGen( + Stmt* stmt, + const std::vector& buffer_args, + at::Device device = at::Device(at::kCPU)) + : CodeGen(stmt, buffer_args, device) { + Initialize(); + } + + ~BlockCodeGen() override; + + void call(const std::vector& args) override; + + void Initialize(); + + std::string getCodeText() override { + return oss_.str(); + } + + private: + UniqueNameManager* name_manager() { + if (!printer_) { + throw std::runtime_error("Null IRPrinter is not expected"); + } + return printer_->name_manager(); + } + + std::ostream& os() { + return printer_->os(); + } + + std::ostringstream oss_; + std::unique_ptr printer_; + std::unique_ptr block_analysis_; + + std::string GetUniqueFuncName(const std::string& func_prefix); +}; +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h index 2f823f9ae4d8bb1..1d67ef2ab89b4ac 100644 --- a/torch/csrc/jit/tensorexpr/codegen.h +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -49,6 +49,14 @@ class TORCH_API CodeGen { return device_; } + // This function returns the generated code as + // a string. Currently only implemented for Block. + // TODO. Rename this, as we can return other than string + // and implement for other backends. + virtual std::string getCodeText() { + return (""); + } + virtual void call(const std::vector& args) = 0; private: diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.h b/torch/csrc/jit/tensorexpr/ir_simplifier.h index 4b6f695c9ca4f13..789a5b6906e1adf 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.h +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.h @@ -24,7 +24,6 @@ namespace tensorexpr { // A bunch of helpers for determine the Dtype of the output of a multi argument // Term or Polynomial. -namespace { template Dtype promoteTypesVec(const Expr* s, std::vector& v) { Dtype t = s->dtype(); @@ -86,7 +85,7 @@ Dtype promoteTypesVar(const ExprType* e, Args... es) { } // Creates a new Expr of the given type with the provided lhs and rhs. -static const Expr* newBinaryOpOfType( +inline const Expr* newBinaryOpOfType( IRNodeType expr_type, const Expr* lhs, const Expr* rhs, @@ -123,7 +122,7 @@ static const Expr* newBinaryOpOfType( // Uses the evaluator to fold an Expression with constant terms. // E.g. evaluateOp(Add(3, 4)) => 7. // Expr v must not have any unbound Vars. -static Expr* evaluateOp(const Expr* v) { +inline Expr* evaluateOp(const Expr* v) { ExprHandle handle(v); ExprEval eval(handle); @@ -142,8 +141,6 @@ static Expr* evaluateOp(const Expr* v) { return nullptr; } -} // namespace - // A Term represents a grouping of Exprs through multiplication. // E.g. product(scalar, *variables). class Term : public ExprNode { @@ -414,8 +411,7 @@ class TORCH_API PolynomialTransformer : public IRSimplifierBase { static const Expr* simplify(const Expr* e); static ExprHandle simplify(const ExprHandle& e); static Stmt* simplify(Stmt* e); - -}; // namespace tensorexpr +}; // Expands Terms and Polynomial expressions into primitive operations. // Does some simple factorization and reordering. diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 22d829932a1d0e5..fbf9d640f8b080c 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -18,6 +18,7 @@ static int te_cuda_pointwise_loop_levels = -1; static int te_cuda_pointwise_block_count = -1; static int te_cuda_pointwise_block_size = -1; static bool fallback_allowed = false; +static bool te_generate_block_code = false; bool setFallbackAllowed(bool value) { bool old_value = fallback_allowed; @@ -48,6 +49,13 @@ int& getTECudaPointwiseBlockSize() { return te_cuda_pointwise_block_size; } +// TODO: Remove this global var +// Ideally Block code gen should be decided +// based on device type in tensor. +bool& getTEGenerateBlockCode() { + return te_generate_block_code; +} + } // namespace tensorexpr } // namespace jit } // namespace torch @@ -1214,7 +1222,8 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) { } void TensorExprKernel::flattenTensors(BackendType backendType) { - if (backendType != BackendType::kCudaCodeGen) { + if (backendType != BackendType::kCudaCodeGen && + backendType != BackendType::kBlockCodeGen) { // We only need to flatten for GPU, for other backends just use the same // tensors. flatTensorOutputs_ = tensorOutputs_; @@ -1264,9 +1273,11 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { torch::jit::tensorexpr::LoopNest l(flatTensorOutputs_); GRAPH_DEBUG("Original Stmt:\n", std::to_string(l.root_stmt()), "\n"); + bool hasReduction = NodeFinder::find(l.root_stmt()).size() != 0; + // Compute non-output tensors_ inline for (auto& p : tensors_) { - if (!l.hasLoopBodyFor(p.second)) { + if (!l.hasLoopBodyFor(p.second) || hasReduction) { continue; } Stmt* loop = l.getLoopBodyFor(p.second); @@ -1323,12 +1334,36 @@ Stmt* TensorExprKernel::generateStmt(BackendType backendType) { } } - bool allowVectorization = - NodeFinder::find(l.root_stmt()).size() == 0; + if (backendType == kBlockCodeGen) { + auto block_analysis = std::make_unique(); + for (size_t i = 0; i < flatTensorOutputs_.size(); i++) { + const int default_fp16_blocksize = 16; + const int default_uint8_blocksize = 32; + int blockSize = default_fp16_blocksize; + // We only handle looplevels == 2 for now + Tensor* tensor = flatTensorOutputs_[i]; + // Run Block analysis to get multi dim buffer info + auto root_stmt = l.root_stmt(); + root_stmt->accept(block_analysis.get()); + + if (tensor->buf()->dtype().scalar_type() == ScalarType::Byte) { + blockSize = default_uint8_blocksize; + } + l.computeInline(l.getLoopBodyFor(tensorOutputs_[i])); + For* outer; + For* inner; + std::vector loops = l.getLoopStmtsFor(tensor); + TORCH_INTERNAL_ASSERT(loops.size() > 0, "loops should not be empty"); + l.splitWithMask(loops[0], blockSize, &outer, &inner); + l.setGPUBlockIndex(outer, 0); + l.setGPUThreadIndex(inner, 0); + l.setBufferMap(outer, block_analysis->getBufferMap()); + } + } l.prepareForCodegen(); - if (backendType == kLLVMCodeGen && allowVectorization) { + if (backendType == kLLVMCodeGen && !hasReduction) { std::vector innerLoops; std::vector worklist; @@ -1408,6 +1443,8 @@ std::string TensorExprKernel::getCodeGenName(BackendType backendType) { return "llvm_codegen"; case kSimpleIREval: return "simple_ir_eval"; + case kBlockCodeGen: + return "block_codegen"; default: throw std::runtime_error( "invalid backend type: " + @@ -1529,6 +1566,8 @@ TensorExprKernel::BackendType TensorExprKernel::inferBackendTypeFromDevice( BackendType backendType = BackendType::kUninitialized; if (device.type() == at::kCUDA) { backendType = kCudaCodeGen; + } else if (device.type() == at::kCPU && getTEGenerateBlockCode()) { + backendType = kBlockCodeGen; } else if (device.type() == at::kCPU) { #ifdef TORCH_ENABLE_LLVM backendType = kLLVMCodeGen; @@ -1658,15 +1697,12 @@ TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo( // aten::sum takes the input tensor named self. auto sizes = sizesForValue(node->namedInput(attr::self)); const auto inputs = node->inputs(); + int rank = sizes.size(); if (inputs.size() > 2) { + auto nodeAxes = getReductionAxes(node); // Canonicalize axes: wrap around, sort and make unique. - auto axesValue = node->namedInput(attr::dim); - TORCH_INTERNAL_ASSERT(axesValue->node()->kind() == prim::ListConstruct); - for (auto axisNode : axesValue->node()->inputs()) { - int rank = sizes.size(); - int axis = at::maybe_wrap_dim( - constant(axisNode).AsNode()->value(), rank); - axes.push_back(axis); + for (auto axis : nodeAxes) { + axes.push_back(at::maybe_wrap_dim(axis, rank)); } std::sort(axes.begin(), axes.end()); axes.erase(std::unique(axes.begin(), axes.end()), axes.end()); @@ -1701,9 +1737,31 @@ TensorExprKernel::ReductionInfo TensorExprKernel::getReductionInfo( return {reductionDims, outputDims, axes, keepdim, dtype}; } +std::vector TensorExprKernel::getReductionAxes( + const torch::jit::Node* node) { + std::vector axes; + auto axesNode = node->namedInput(attr::dim)->node(); + // There are two possible representations for reduction axes: + // 1. A prim::ListConstruct of integer constants. + // 2. A prim::Constant list of integer ival's. + // We need to handle both of them. + if (axesNode->kind() == prim::ListConstruct) { + for (auto axisNode : axesNode->inputs()) { + axes.push_back(constant(axisNode).AsNode()->value()); + } + return axes; + } + TORCH_INTERNAL_ASSERT(axesNode->kind() == prim::Constant); + TORCH_INTERNAL_ASSERT(axesNode->kindOf(attr::value) == AttributeKind::ival); + const auto& genericList = axesNode->ival(attr::value).toList(); + for (const IValue axisNode : genericList) { + axes.push_back(axisNode.toInt()); + } + return axes; +} + void TensorExprKernel::compile() { KernelScope kernelScope(&kernelArena_); - // Bind inputs to buffers. nInputs_ = graph_->inputs().size(); for (auto const& input : graph_->inputs()) { @@ -1740,7 +1798,6 @@ void TensorExprKernel::compile() { device_ = pickDeviceType(graph_->inputs()); BackendType backendType = inferBackendTypeFromDevice(device_); Stmt* stmt = generateStmt(backendType); - // Set up formal params (inputs, then outputs) for kernel. std::vector params = prepareBufferArgs(); diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index e71d0f435515206..1da878089a31dd3 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -30,12 +31,17 @@ class TORCH_API TensorExprKernel { Stmt* getCodeGenStmt(); + std::string getCodeText() { + return codegen_->getCodeText(); + } + private: enum BackendType { kUninitialized, kSimpleIREval, kLLVMCodeGen, kCudaCodeGen, + kBlockCodeGen, }; void compile(); @@ -143,6 +149,9 @@ class TORCH_API TensorExprKernel { // Get the reduction info for the given node, based on properties and inputs. ReductionInfo getReductionInfo(const torch::jit::Node* node); + // Get the reduction axes for the given node, based on properties and inputs. + std::vector getReductionAxes(const torch::jit::Node* node); + private: struct ShapeArg { size_t idx; @@ -200,6 +209,7 @@ class TORCH_API TensorExprKernel { TORCH_API int& getTECudaPointwiseLoopLevels(); TORCH_API int& getTECudaPointwiseBlockCount(); TORCH_API int& getTECudaPointwiseBlockSize(); +TORCH_API bool& getTEGenerateBlockCode(); TORCH_API bool fallbackAllowed(); TORCH_API bool setFallbackAllowed(bool value); diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index dc51e757c2dd174..32de6b3b9bc0118 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -1258,6 +1258,12 @@ void LoopNest::setGPUThreadIndex(For* f, int thread_index) { f->set_gpu_thread_index(thread_index); } +void LoopNest::setBufferMap( + For* f, + const std::unordered_map& map) { + f->set_buffer_map(map); +} + Stmt* LoopNest::getLoopBodyFor(Tensor* t) const { return tensor_to_stmt_.at(t); } diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h index 13c9dff3952c802..6282e250367188b 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.h +++ b/torch/csrc/jit/tensorexpr/loopnest.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -44,6 +45,9 @@ class TORCH_API LoopNest { void setGPUBlockIndex(For* f, int idx); void setGPUThreadIndex(For* f, int idx); + void setBufferMap( + For* f, + const std::unordered_map& map); // Insert a temporary computation of statement S in the scope of loop AT. // S is assumed to be a Store or a Block containing a Store. Along with the diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h index 8372551460c7f0a..1f2358d203ed68c 100644 --- a/torch/csrc/jit/tensorexpr/reduction.h +++ b/torch/csrc/jit/tensorexpr/reduction.h @@ -81,10 +81,11 @@ class ReduceOp : public ExprNode { std::vector reduce_args_; }; -// A Reducer is a user interface describing a particular reduction operation. It -// has three components: An initializtion value, a way of interacting each value -// with the accumulation, and a method for obtaining the current value to be -// reduced. It is materialized into a ReduceOp when loop variables are known. +// A Reducer is a user interface describing a particular reduction +// operation. It has three components: An initialization value, a way of +// interacting each value with the accumulation, and a method for obtaining the +// current value to be reduced. It is materialized into a ReduceOp when loop +// variables are known. class Reducer { public: Reducer(ExprHandle init, ReduceInteraction& interaction) @@ -173,8 +174,7 @@ class Sum : public Reducer { }) {} }; -namespace { -ExprHandle maximumVal(ScalarType type) { +inline ExprHandle maximumVal(ScalarType type) { switch (type) { #define MAX_BY_TYPE_CASE(Type, Name) \ case ScalarType::Name: \ @@ -187,7 +187,7 @@ ExprHandle maximumVal(ScalarType type) { return ExprHandle(); } -static ExprHandle minimumVal(ScalarType type) { +inline ExprHandle minimumVal(ScalarType type) { switch (type) { #define MAX_BY_TYPE_CASE(Type, Name) \ case ScalarType::Name: \ @@ -198,7 +198,6 @@ static ExprHandle minimumVal(ScalarType type) { throw unsupported_dtype(); } } -} // namespace class Maximum : public Reducer { public: diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h index ea135833f2ba6e4..65ac29fab40cbc3 100644 --- a/torch/csrc/jit/tensorexpr/stmt.h +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -543,9 +543,19 @@ class TORCH_API LoopOptions { return gpu_block_index_ == IDX_UNSET && gpu_thread_index_ == IDX_UNSET; } + void set_buffer_mapping( + const std::unordered_map& map) { + map_input_to_tensor_bufs_ = map; + } + + std::unordered_map get_buffer_mapping() const { + return map_input_to_tensor_bufs_; + } + private: int gpu_block_index_{IDX_UNSET}; int gpu_thread_index_{IDX_UNSET}; + std::unordered_map map_input_to_tensor_bufs_; }; class TORCH_API For : public StmtNode { @@ -639,6 +649,10 @@ class TORCH_API For : public StmtNode { loop_options_.set_gpu_thread_index(thread_index); } + void set_buffer_map(const std::unordered_map& map) { + loop_options_.set_buffer_mapping(map); + } + For* cloneWithNewBody(Stmt* body) const { return new For(var_, start_, stop_, body, loop_options_); } diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h index cfb82662c60e7c6..81bcbee7865245d 100644 --- a/torch/csrc/jit/tensorexpr/tensor.h +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -99,9 +99,7 @@ TORCH_API Tensor* Compute( const std::vector& dim_args, const std::function&)>& body_func); -namespace { - -static inline void unpack_dim_args( +inline void unpack_dim_args( const std::vector& dim_args, std::vector* dims, std::vector* vars) { @@ -112,7 +110,6 @@ static inline void unpack_dim_args( vars->push_back(new Var(dim_arg.name_hint(), kInt)); } } -} // namespace // Handle reductions over a Reducer and a body_func which produces values. template diff --git a/torch/custom_class.h b/torch/custom_class.h index 62d498c274169bf..3805cfafc91ab5f 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -137,15 +137,19 @@ class class_ { /// /// Currently, both the `get_state` and `set_state` callables must be /// C++ lambda expressions. They should have the following signatures, - /// where `CurClass` is the class you're registering and `T` is some object + /// where `CurClass` is the class you're registering and `T1` is some object /// that encapsulates the state of the object. /// - /// __getstate__(intrusive_ptr) -> T - /// __setstate__(T) -> intrusive_ptr + /// __getstate__(intrusive_ptr) -> T1 + /// __setstate__(T2) -> intrusive_ptr /// - /// `T` must be an object that is convertable to IValue by the same rules + /// `T1` must be an object that is convertable to IValue by the same rules /// for custom op/method registration. /// + /// For the common case, T1 == T2. T1 can also be a subtype of T2. An + /// example where it makes sense for T1 and T2 to differ is if __setstate__ + /// handles legacy formats in a backwards compatible way. + /// /// Example: /// /// .def_pickle( @@ -207,16 +211,17 @@ class class_ { getstate_schema.returns().size() == 1, "__getstate__ should return exactly one value for serialization. Got: ", format_getstate_schema()); + auto ser_type = getstate_schema.returns().at(0).type(); auto setstate_schema = classTypePtr->getMethod("__setstate__").getSchema(); auto arg_type = setstate_schema.arguments().at(1).type(); TORCH_CHECK( - (*arg_type == *ser_type), - "__setstate__'s argument should be the same type as the " - "return value of __getstate__. Got ", - arg_type->repr_str(), + ser_type->isSubtypeOf(arg_type), + "__getstate__'s return type should be a subtype of " + "input argument of __setstate__. Got ", + ser_type->repr_str(), " but expected ", - ser_type->repr_str()); + arg_type->repr_str()); return *this; } diff --git a/torch/distributed/algorithms/ddp_comm_hooks/__init__.py b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py new file mode 100644 index 000000000000000..51678fe44590e6c --- /dev/null +++ b/torch/distributed/algorithms/ddp_comm_hooks/__init__.py @@ -0,0 +1,42 @@ +from enum import Enum +from functools import partial + +import torch.distributed.algorithms.ddp_comm_hooks.default_hooks as default +import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quantization +from torch.nn.parallel import DistributedDataParallel + + +def ddp_comm_hook_wrapper(comm_hook, model, state): + model._register_comm_hook(state, comm_hook) + + +class DDPCommHookType(Enum): + ''' + DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example, + you can register allreduce hook by + ``DDPCommHookType.ALLREDUCE.value(model=model, state=process_group)``. + ''' + ALLREDUCE = partial(ddp_comm_hook_wrapper, comm_hook=default.allreduce_hook) + FP16_COMPRESS = partial(ddp_comm_hook_wrapper, comm_hook=default.fp16_compress_hook) + QUANTIZE_PER_TENSOR = partial( + ddp_comm_hook_wrapper, comm_hook=quantization.quantization_pertensor_hook + ) + QUANTIZE_PER_CHANNEL = partial( + ddp_comm_hook_wrapper, comm_hook=quantization.quantization_perchannel_hook + ) + + +def register_ddp_comm_hook( + comm_hook_type: DDPCommHookType, model: DistributedDataParallel, state=None +): + """ + Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks`` + to the DDP model. User can specify the type of hook as an enum + ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will + be passed to the model. + + Example:: + >>> register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, model, state) + """ + comm_hook_type.value(model=model, state=state) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py new file mode 100644 index 000000000000000..16638a915f706a0 --- /dev/null +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -0,0 +1,118 @@ +import torch +import torch.distributed as dist + + +def allreduce_hook( + process_group: object, bucket: dist._GradBucket +) -> torch.futures.Future: + """ + This DDP communication hook just calls ``allreduce`` using ``GradBucket`` + tensors. Once gradient tensors are aggregated across all workers, its ``then`` + callback takes the mean and returns the result. If user registers this hook, + DDP results is expected to be same as the case where no hook was registered. + Hence, this won't change behavior of DDP and user can use this as a reference + or modify this hook to log useful information or any other purposes while + unaffecting DDP behavior. + + Example:: + >>> ddp_model._register_comm_hook(process_group, allreduce_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = ( + process_group.size() if process_group is not None else dist.get_world_size() + ) + + tensor = bucket.get_tensors()[0] + fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future() + + def then_callback(fut): + return [fut.value()[0].div_(world_size)] + + return fut.then(then_callback) + + +def fp16_compress_hook(process_group: object, bucket: dist._GradBucket): + """ + This DDP communication hook implements a simple gradient compression + approach that converts ``GradBucket`` tensors whose type is assumed to be + ``torch.float32`` to half-precision floating point format (``torch.float16``). + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, its then callback called ``decompress`` converts the + aggregated result back to ``float32`` and takes the mean. + + Example:: + >>> ddp_model._register_comm_hook(process_group, fp16_compress_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + world_size = ( + process_group.size() if process_group is not None else dist.get_world_size() + ) + + compressed_tensor = bucket.get_tensors()[0].to(torch.float16) + + fut = dist.all_reduce( + compressed_tensor, group=group_to_use, async_op=True + ).get_future() + + def decompress(fut): + return [fut.value()[0].to(torch.float32).div_(world_size)] + + return fut.then(decompress) + + +def _get_allgather_out_list(all_gather_in_list, world_size): + out_list = [ + torch.zeros_like( + all_gather_in_list, + device=all_gather_in_list.device, + dtype=all_gather_in_list.dtype, + ) + for _ in range(world_size) + ] + return out_list + + +def _allgather_then_aggregate_hook( + process_group: object, bucket: dist._GradBucket +) -> torch.futures.Future: + """ + Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors + and its ``then`` callback aggregates the gathered gradient tensors and takes + mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with + W workers, both the computation and communication time scale as O(W) for + allgather compared to O(logW) for allreduce. Therefore, this hook is expected + to be much slower than ``allreduce_hook`` although both essentially do the + same thing with the gradients. + + .. warning :: + This is for test and experiments. User is suggested to use a faster + alternative called ``allreduce_hook`` that uses ``allreduce`` protocol + instead of ``allgather`` protocol. + + Example:: + >>> ddp_model._register_comm_hook(process_group, allreduce_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = ( + process_group.size() if process_group is not None else dist.get_world_size() + ) + + tensor = bucket.get_tensors()[0] + fut = dist.all_gather( + _get_allgather_out_list(tensor, world_size), + tensor, + group=group_to_use, + async_op=True, + ).get_future() + + def aggregate(fut): + all_ranks_tensor = fut.value()[0] + tensor = bucket.get_tensors()[0] + for r, gathered_tensor in enumerate(all_ranks_tensor): + if r != rank: + tensor += gathered_tensor + + return [tensor.div_(world_size)] + + return fut.then(aggregate) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py new file mode 100644 index 000000000000000..afac1ee66873b1b --- /dev/null +++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -0,0 +1,217 @@ +import torch +import torch.distributed as dist +from torch import nn + + +def _quantize_per_tensor_cuda(x, scale, zero_point): + y = torch.round(x / scale) + zero_point + y = torch.clamp(y, 0, 255).to(torch.uint8) + return y + + +def _dequantize_per_tensor_cuda(y, scale, zero_point): + x = scale * (y.to(torch.float32) - zero_point) + return x + + +def _quantize_per_channel_cuda(x, scale, zero_point): + y = torch.zeros(x.size(), device=x.device) + for i in range(x.size()[0]): + y[i, :] = torch.round(x[i, :] / scale[i]) + zero_point[i] + y = torch.clamp(y, 0, 255).to(torch.uint8) + return y + + +def _dequantize_per_channel_cuda(y, scale, zero_point): + y = y.to(torch.float32).cuda(y.device) + x = torch.zeros_like(y, device=y.device) + for i in range(x.size()[0]): + x[i, :] = scale[i] * (y[i, :] - zero_point[i]) + return x + + +def _get_allgather_out_list(all_gather_in_list, world_size): + out_list = [ + torch.zeros_like( + all_gather_in_list, + device=all_gather_in_list.device, + dtype=all_gather_in_list.dtype, + ) + for _ in range(world_size) + ] + return out_list + + +def quantization_pertensor_hook( + process_group: object, bucket: dist._GradBucket +) -> torch.futures.Future: + """ + Applies the ``torch.quantize_per_tensor`` logic to DDP using ``allgather`` + protocol. Workers first allgather the scale and zero point of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensors, and uses ``allgather`` to communicate these accross all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes and + aggregates each quantized gradient tensors locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> ddp_model._register_comm_hook(process_group, quantization_pertensor_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = ( + process_group.size() if process_group is not None else dist.get_world_size() + ) + + tensor = bucket.get_tensors()[0] + + myObserver = torch.quantization.MinMaxObserver().cuda(tensor.device) + myObserver(tensor) + + s, z = myObserver.calculate_qparams() + s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + + # First, allgather scale and zeros. + fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros accross all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their own ``GradBucket`` tensors. + quantized_tensor = _quantize_per_tensor_cuda( + tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1] + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = torch.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_tensor_cuda( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return [aggregated_dequantized_tensor / world_size] + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) + + +def quantization_perchannel_hook( + process_group: object, bucket: dist._GradBucket, bucket_size=512 +) -> torch.futures.Future: + """ + Applies the ``torch.quantize_per_channel`` logic to DDP using ``allgather`` + protocol. Compared to pertensor, the main motivation of perchannel is + for considerably large tensors such as a tensor that contains 6 million + elements quantizing per a bucket size of 512 (or 128) elements may significantly + increase the resolution. + + It first splits ``GradBucket`` tensors into multiple chunks (channels) of ``bucket_size`` + elements. Then, workers allgather the scales and zero points of their own + ``GradBucket`` prior to the quantization. After all workers have that information, + the first ``then`` callback called ``quantize_and_allgather`` quantizes worker's + own gradient tensors, and uses ``allgather`` to communicate these accross all workers. + The final ``then`` callback called ``dequantize_and_aggregate``, dequantizes, flattens, and + aggregates each quantized gradient tensors locally and returns the mean. + + .. warning :: + This is experimental, and uses ``allgather`` protocol which is considerably slower than + ``allreduce`` protocol. It works only with flattened grads. + + Example:: + >>> ddp_model._register_comm_hook(process_group, quantization_perchannel_hook) + """ + group_to_use = process_group if process_group is not None else dist.group.WORLD + rank = process_group.rank() if process_group is not None else dist.get_rank() + world_size = ( + process_group.size() if process_group is not None else dist.get_world_size() + ) + + tensor = bucket.get_tensors()[0] + + tensor_in_channels = ( + nn.functional.pad( + input=tensor, + pad=(0, bucket_size - len(tensor) % bucket_size), + mode="constant", + value=0, + ) + .view(-1, bucket_size) + .cuda(tensor.device) + ) + + myPerChannelObserver = torch.quantization.PerChannelMinMaxObserver().cuda( + tensor.device + ) + myPerChannelObserver(tensor_in_channels) + + s_ch, z_ch = myPerChannelObserver.calculate_qparams() + s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device) + + all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size) + # First, allgather scale and zeros. + fut = dist.all_gather( + all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True + ).get_future() + + def quantize_and_allgather(fut): + # Store scale and zeros accross all workers. + all_ranks_s_and_z = fut.wait()[0] + # All workers quantize their corresponding ``GradBucket`` tensors. + quantized_tensor = _quantize_per_channel_cuda( + tensor_in_channels, + all_ranks_s_and_z[rank, 0, :], + all_ranks_s_and_z[rank, 1, :], + ) + # Allgather quantized tensors. + fut = dist.all_gather( + _get_allgather_out_list(quantized_tensor, world_size), + quantized_tensor, + group=group_to_use, + async_op=True, + ).get_future() + + return fut.wait() + + def dequantize_and_aggregate(fut): + all_ranks_quantized_tensor = fut.wait()[0] + + aggregated_dequantized_tensor = torch.zeros_like( + all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32 + ) + # Using previously allgathered scales and zeros, dequantize gradient tensors + # locally and then aggregate them. + for r, quantized_tensor in enumerate(all_ranks_quantized_tensor): + aggregated_dequantized_tensor += _dequantize_per_channel_cuda( + quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1] + ) + + return [ + torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[ + : tensor.size()[0] + ] + / world_size + ] + + return fut.then(quantize_and_allgather).then(dequantize_and_aggregate) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 95bc4f9f0448f7e..77b9c90a6316362 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -3,6 +3,10 @@ from typing import Callable, Any, List, Dict, Optional, Tuple import builtins import torch +import keyword + +def _shadows_builtin_name(name: str) -> bool: + return name in builtins.__dict__ or name in keyword.kwlist def _is_magic(x: str) -> bool: return x.startswith('__') and x.endswith('__') @@ -71,16 +75,16 @@ def add_use(n: Node): return n map_arg(a, add_use) - def create_node(self, op: str, target: Target, - args: Optional[Tuple[Argument, ...]] = None, - kwargs: Optional[Dict[str, Argument]] = None, + def create_node(self, op: str, target: Target, + args: Optional[Tuple[Argument, ...]] = None, + kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None): assert op in ('call_function', 'call_method', 'get_param', 'call_module', 'placeholder') args = () if args is None else args kwargs = {} if kwargs is None else kwargs self._mark_uses(args) self._mark_uses(kwargs) - n = Node(self, name if name is not None else self._name(target or op), op, target, args, kwargs) + n = Node(self, name if name is not None else self._name(target), op, target, args, kwargs) self.nodes.append(n) return n @@ -91,9 +95,12 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], Argument] = lamb kwargs = map_arg(node.kwargs, arg_transform) assert isinstance(args, tuple) assert isinstance(kwargs, dict) - return self.create_node( - node.op, node.target, args, kwargs, - self._name(node.name)) + if node.op == "placeholder": + # Placeholder names are user-visible, so they should be copied as-is without normalizing them. + name = node.name + else: + name = self._name(node.name) + return self.create_node(node.op, node.target, args, kwargs, name) def output(self, result: Argument): self.result = result @@ -112,17 +119,15 @@ def _name(self, target: Target) -> str: if op not in self._used_names: self._used_names[op] = 0 - if not hasattr(torch, op) and not hasattr(torch.nn.functional, op) and not hasattr(torch.nn, op): + # Avoid shadowing PyTorch and Python builtins. + if not hasattr(torch, op) and \ + not hasattr(torch.nn.functional, op) and \ + not hasattr(torch.nn, op) and \ + not _shadows_builtin_name(op): return op i = self._used_names[op] = self._used_names[op] + 1 return f'{op}_{i}' - def get_param(self, name: str) -> Node: - return self.create_node('get_param', name) - - def placeholder(self, name: str) -> Node: - return self.create_node('placeholder', target=name, name=name.replace('*', '')) - def python_code(self, root_module: str) -> Tuple[str, str, List[str]]: free_vars: List[str] = [] body: List[str] = [] diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 379e9a6600e4c01..f8db5439020def7 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -41,6 +41,7 @@ class GraphModuleImpl(cls): # type: ignore def __init__(self, root: torch.nn.Module, graph: Graph): super().__init__() self.root = root + self.training = self.root.training self.graph = graph self._generate_forward() diff --git a/torch/fx/symbolic_trace.py b/torch/fx/symbolic_trace.py index cb9571c1b34dd60..b86469159afaef9 100644 --- a/torch/fx/symbolic_trace.py +++ b/torch/fx/symbolic_trace.py @@ -47,28 +47,53 @@ class DelegateBase: def __init__(self, graph: Graph): self.graph = graph - # A method to insert a graph node given target, args, kwargs, and name. - # This method can be overridden to do extra checking, validation, or - # modification of values used in node creation. For example, one might - # want to disallow in-place operations from being recorded. def create_node(self, kind : str, target : Union[str, Callable], args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None) -> Node: + """ + Inserts a graph node given target, args, kwargs, and name. + + This method can be overridden to do extra checking, validation, or + modification of values used in node creation. For example, one might + want to disallow in-place operations from being recorded. + """ return self.graph.create_node(kind, target, args, kwargs, name) - # A method to specify whether a given `nn.Module` is a "leaf" - # module. Leaf modules are the atomic units that appear in - # the IR, referenced by `call_module` calls. By default, - # Modules in the PyTorch standard library namespace (torch.nn) - # are leaf modules. All other modules are traced through and - # their constituent ops are recorded, unless specified otherwise - # via this parameter. + def placeholder(self, name): + """ + Inserts a new placeholder (i.e. graph input) + + This method can be overridden to do extra modification, e.g. attach more attributes to the node. + """ + return self.create_node('placeholder', target=name, args=(), kwargs={}, name=name.replace('*', '')) + + def get_param(self, target): + """ + Inserts a graph node representing access of the parameter with full qual name `target` + + This method can be overridden to do extra modification, e.g. attach more attributes to the node. + """ + return self.create_node('get_param', target, args=(), kwargs={}) + def is_leaf_module(self, m: torch.nn.Module) -> bool: + """ + A method to specify whether a given `nn.Module` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by `call_module` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + """ return m.__module__.startswith('torch.nn') and not isinstance(m, torch.nn.Sequential) - # A method that lowers the objects seen as arguments during symbolic evaluation - # into Argument types that can be stored in IR. - # Can be override to support more trace-specific types. def create_arg(self, a: Any) -> Argument: + """ + A method that lowers the objects seen as arguments during symbolic evaluation + into Argument types that can be stored in IR. + + Can be override to support more trace-specific types. + """ # aggregates if isinstance(a, (tuple, list)): return type(a)(self.create_arg(elem) for elem in a) @@ -104,14 +129,14 @@ def create_arg(self, a: Any) -> Argument: if isinstance(a, torch.nn.Parameter): for n, p in self.root.named_parameters(): if a is p: - return self.graph.get_param(n) + return self.get_param(n) raise NameError('parameter is not a member of this module') return super().create_arg(a) def _proxy_placeholder(name: str, delegate: DelegateBase) -> Proxy: - return Proxy(delegate.graph.placeholder(name), delegate) + return Proxy(delegate.placeholder(name), delegate) # Symbolic tracing API # diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 7ceca7f52759df3..fd61228f3379089 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -9,6 +9,7 @@ _overload, _overload_method, ignore, + is_scripting, export, unused, ) @@ -16,7 +17,6 @@ script, Attribute, ScriptModule, - is_scripting, script_method, RecursiveScriptModule, ScriptWarning, diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index 11cb49f45125fd2..1a92ad660712f25 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -115,7 +115,7 @@ def _get_builtin_table(): def register_all(mod): for name in dir(mod): v = getattr(mod, name) - if callable(v) and not _is_special_functional_bound_op(v): + if callable(v) and not _is_special_functional_bound_op(v) and v is not torch.no_grad: _builtin_ops.append((v, "aten::" + name)) for mod in _modules_containing_builtins: register_all(mod) diff --git a/torch/jit/_script.py b/torch/jit/_script.py index de44387bfaae01b..f1b1d18cdc00c16 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -470,7 +470,7 @@ def code_with_constants(self): def save(self, *args, **kwargs): r""" - save(f, _extra_files=ExtraFilesMap{}) + save(f, _extra_files={}) See :func:`torch.jit.save ` for details. """ @@ -935,28 +935,6 @@ def forward(self, input): return fn -def is_scripting(): - r""" - Function that returns True when in compilation and False otherwise. This - is useful especially with the @unused decorator to leave code in your - model that is not yet TorchScript compatible. - .. testcode:: - - import torch - - @torch.jit.unused - def unsupported_linear_op(x): - return x - - def linear(x): - if not torch.jit.is_scripting(): - return torch.linear(x) - else: - return unsupported_linear_op(x) - """ - return False - - # overloads are registered in _jit_internal and compiled here so that _overload # can be used in nn/functional.py without an import cycle @@ -1068,9 +1046,6 @@ def _recursive_compile_class(obj, loc): _compile_and_register_class(obj, rcb, _qual_name) -_register_builtin(is_scripting, "aten::is_scripting") - - class CompilationUnit(object): def __init__(self, lang=None, _frames_up=0): self._c = torch._C.CompilationUnit() @@ -1095,3 +1070,4 @@ def _unwrap_optional(x): _register_builtin(_unwrap_optional, "aten::_unwrap_optional") +_register_builtin(_jit_internal.is_scripting, "aten::is_scripting") diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py index 900bbd7e95a3f34..78c0d0e989c1018 100644 --- a/torch/jit/_serialization.py +++ b/torch/jit/_serialization.py @@ -15,10 +15,8 @@ from torch.jit._recursive import wrap_cpp_module from torch.serialization import validate_cuda_device -DEFAULT_EXTRA_FILES_MAP = torch._C.ExtraFilesMap() - -def save(m, f, _extra_files=DEFAULT_EXTRA_FILES_MAP): +def save(m, f, _extra_files=None): r""" Save an offline version of this module for use in a separate process. The saved module serializes all of the methods, submodules, parameters, and @@ -74,10 +72,11 @@ def forward(self, x): torch.jit.save(m, buffer) # Save with extra files - extra_files = torch._C.ExtraFilesMap() - extra_files['foo.txt'] = 'bar' + extra_files = {'foo.txt': b'bar'} torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) """ + if _extra_files is None: + _extra_files = {} if isinstance(f, str) or isinstance(f, pathlib.Path): m.save(f, _extra_files=_extra_files) else: @@ -85,7 +84,7 @@ def forward(self, x): f.write(ret) -def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): +def load(f, map_location=None, _extra_files=None): r""" Load a :class:`ScriptModule` or :class:`ScriptFunction` previously saved with :func:`torch.jit.save ` @@ -133,8 +132,7 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): torch.jit.load(buffer, map_location='cpu') # Load with extra files. - extra_files = torch._C.ExtraFilesMap() - extra_files['foo.txt'] = 'bar' + extra_files = {'foo.txt': ''} # values will be replaced with data torch.jit.load('scriptmodule.pt', _extra_files=extra_files) print(extra_files['foo.txt']) @@ -155,6 +153,8 @@ def load(f, map_location=None, _extra_files=DEFAULT_EXTRA_FILES_MAP): raise ValueError("The provided filename {} is a directory".format(f)) map_location = validate_map_location(map_location) + if _extra_files is None: + _extra_files = {} cu = torch._C.CompilationUnit() if isinstance(f, str) or isinstance(f, pathlib.Path): diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index afae5f5ae4fa3b2..92925a01b1045f5 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -304,13 +304,8 @@ class WithItemBuilder(Builder): def build_withitem(ctx, item): lineno = item.context_expr.lineno start = item.context_expr.col_offset + end = start + len(pretty_node_names[ast.With]) op_vars = item.optional_vars - - if op_vars: - end = op_vars.col_offset + len(op_vars.id) - else: - end = start + len(item.context_expr.id) - r = ctx.make_range(lineno, start, end) return WithItem(r, build_expr(ctx, item.context_expr), build_expr(ctx, op_vars) if op_vars else None) diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 9e15fa36268c616..dfae068de24401f 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -13,7 +13,7 @@ #include #include -#include +#include #ifdef USE_CUDA #include diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index a8d22ae53bba912..e6e8c595242ae25 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -409,6 +409,17 @@ ProcessGroupNCCL::ProcessGroupNCCL( std::string(NCCL_BLOCKING_WAIT)); } + // If single-process single-device mode, WorkNCCL::getFuture is supported, + // so get a dedicated stream for each device to run FutureNCCL then callbacks. + // Depending on the device index of collective outputs, WorkNCCL passes + // the corresponding device's then callback stream to FutureNCCL. + futureNCCLCallbackStreams_.reserve(c10::cuda::device_count()); + for (int device_index = 0; device_index < c10::cuda::device_count(); + device_index++) { + futureNCCLCallbackStreams_.push_back(std::make_shared( + at::cuda::getStreamFromPool(device_index))); + } + #ifdef ENABLE_NCCL_ERROR_CHECKING ncclCommWatchdogThread_ = std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); @@ -776,10 +787,14 @@ c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: TORCH_INTERNAL_ASSERT( outputs_->size() == 1, "WorkNCCL's getFuture API is only supported for single-process single-device mode."); + auto deviceIndex = (*outputs_)[0].device().index(); // Create a new FutureNCCL object after checking for single-process // single-device mode. return c10::make_intrusive( - at::IValue(*outputs_), (*outputs_)[0].device().index(), cudaEvents_); + at::IValue(*outputs_), + deviceIndex, + cudaEvents_, + futureNCCLCallbackStreams_[deviceIndex]); } template @@ -799,8 +814,10 @@ std::shared_ptr ProcessGroupNCCL::collective( // Work itself will create the CUDA events on all GPUs of tensors auto work = initWork(devices); - // Store a reference to outputs to be used by WorkNCCL::getFuture. + // Store references to outputs and futureNCCLCallbackStream to be used by + // WorkNCCL::getFuture. work->outputs_ = std::make_shared>(outputs); + work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_; at::cuda::OptionalCUDAGuard gpuGuard; diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 0481290ab5f4125..3ee8bb4adf63d4c 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -136,6 +136,9 @@ class ProcessGroupNCCL : public ProcessGroup { // Store a reference to NCCL collective's outputs to be used by getFuture. std::shared_ptr> outputs_; + // Store streams that run FutureNCCL then callbacks. + std::vector> + futureNCCLCallbackStreams_; friend class ProcessGroupNCCL; }; @@ -148,10 +151,12 @@ class ProcessGroupNCCL : public ProcessGroup { // or NCCL's barrier(). // // If created by WorkNCCL's getFuture API, FutureNCCL has a reference to - // WorkNCCL's cudaEvents, NCCL collective's outputs, and device index of - // outputs' device. Its value is NCCL collective's outputs. FutureNCCL - // only supports single-process single-device mode where the size of outputs - // is equal to 1. + // WorkNCCL's cudaEvents, NCCL collective's outputs, device index of + // outputs' device, and the ProcesGroupNCCL's dedicated + // futureNCCLCallbackStream for outputs' device that runs all the then + // callbacks called from this FutureNCCL. Its value is NCCL collective's + // outputs. FutureNCCL only supports single-process single-device mode where + // the size of outputs is equal to 1. // // If created by FutureNCCL's then callback, its value becomes the value of // callback() and its cudaEvents will record the NCCL stream that runs that @@ -160,17 +165,20 @@ class ProcessGroupNCCL : public ProcessGroup { // enables synchronizing the appropriate streams and avoids stalling PyTorch's // default stream while running the callback. In case of multiple then // callbacks, the design will work like a chain such that FutureNCCL n will - // wait on the cudaEvents from FutureNCCL n - 1. + // wait on the cudaEvents from FutureNCCL n - 1. All callbacks are executed on + // outputs' device's dedicated futureNCCLCallbackStream. struct FutureNCCL : at::ivalue::Future { public: explicit FutureNCCL( at::IValue value, c10::DeviceIndex deviceIndex, - std::shared_ptr> cudaEvents) + std::shared_ptr> cudaEvents, + std::shared_ptr futureNCCLCallbackStream) : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), value_(std::move(value)), deviceIndex_(deviceIndex), - cudaEvents_(cudaEvents) { + cudaEvents_(cudaEvents), + futureNCCLCallbackStream_(futureNCCLCallbackStream) { TORCH_INTERNAL_ASSERT( cudaEvents_->size() == 1, "FutureNCCL only supports single-process single-device mode."); @@ -181,10 +189,12 @@ class ProcessGroupNCCL : public ProcessGroup { // return value of callback. explicit FutureNCCL( c10::DeviceIndex deviceIndex, - std::shared_ptr> cudaEvents) + std::shared_ptr> cudaEvents, + std::shared_ptr futureNCCLCallbackStream) : at::ivalue::Future(c10::ListType::create(c10::TensorType::get())), deviceIndex_(deviceIndex), - cudaEvents_(cudaEvents) { + cudaEvents_(cudaEvents), + futureNCCLCallbackStream_(futureNCCLCallbackStream) { TORCH_INTERNAL_ASSERT( cudaEvents_->size() == 1, "FutureNCCL only supports single-process single-device mode."); @@ -235,21 +245,11 @@ class ProcessGroupNCCL : public ProcessGroup { // synchronizing FutureNCCL's own cudaEvents with the stream that runs // this callback. This new FutureNCCL's cudaEvents will record the // callback's stream and will have the result value of the callback. - void addCallbackWithStream( - std::function callback, - const c10::cuda::CUDAStream& stream, - std::shared_ptr> thenFutCudaEvents) { - (*cudaEvents_)[0].block(stream); - c10::OptionalStreamGuard streamGuard{c10::Stream(stream)}; + void addCallback(std::function callback) override { + (*cudaEvents_)[0].block(*futureNCCLCallbackStream_); + c10::OptionalStreamGuard streamGuard{ + c10::Stream(*futureNCCLCallbackStream_)}; callback(); - (*thenFutCudaEvents)[0].record(stream); - } - - // We use addCallbackWithStream instead of addCallback. - void addCallback(std::function /* unused */) override { - C10_THROW_ERROR( - Error, - "FutureNCCL uses addCallbackWithStream instead of addCallback."); } // Adds a callback to FutureNCCL, and returns another FutureNCCL to hold @@ -258,32 +258,31 @@ class ProcessGroupNCCL : public ProcessGroup { c10::intrusive_ptr then( std::function callback, at::TypePtr /* unused */) override { - // Get a new stream from pool that will run the callback. - const c10::cuda::CUDAStream stream = - at::cuda::getStreamFromPool(deviceIndex_); - // Create a new cudaEvents object of size 1 that will record callback's - // stream and will be used by the new FutureNCCL. + // Create a new cudaEvents object of size 1 that will record + // futureNCCLCallbackStream_ after callback and will be passed to the new + // FutureNCCL. auto thenFutCudaEvents = std::make_shared>(1); // Create a FutureNCCL without setting a value. - auto fut = - c10::make_intrusive(deviceIndex_, thenFutCudaEvents); + auto fut = c10::make_intrusive( + deviceIndex_, thenFutCudaEvents, futureNCCLCallbackStream_); + // Use the dedicated callback stream to run callback. // Cannot move capture std::function in lambda, because it cannot deduce // the template type for std::function. Hence use std::bind to explicitly // specify types. - addCallbackWithStream( - std::bind( - [&](std::function cb) { - try { - fut->markCompleted(at::IValue(cb())); - } catch (const std::exception& e) { - fut->setError(e.what()); - } - }, - std::move(callback)), - stream, - thenFutCudaEvents); + addCallback(std::bind( + [&](std::function cb) { + try { + fut->markCompleted(at::IValue(cb())); + // In case of chained then callback calls, thenFutCudaEvents + // records callback's stream. + (*thenFutCudaEvents)[0].record(*futureNCCLCallbackStream_); + } catch (const std::exception& e) { + fut->setError(e.what()); + } + }, + std::move(callback))); return fut; } @@ -306,6 +305,7 @@ class ProcessGroupNCCL : public ProcessGroup { at::IValue value_; c10::DeviceIndex deviceIndex_; std::shared_ptr> cudaEvents_; + std::shared_ptr futureNCCLCallbackStream_; c10::optional error_; }; @@ -576,6 +576,10 @@ class ProcessGroupNCCL : public ProcessGroup { // for this map since only the watchdog thread accesses this set. The // set contains the string representation of ncclUniqueId. std::unordered_set abortedComms_; + + // Dedicated CUDA stream for each available device that runs FutureNCCL then + // callbacks. + std::vector> futureNCCLCallbackStreams_; }; } // namespace c10d diff --git a/torch/linalg/__init__.py b/torch/linalg/__init__.py index 61850fab1dba924..5e2b59c45c80714 100644 --- a/torch/linalg/__init__.py +++ b/torch/linalg/__init__.py @@ -13,3 +13,130 @@ Alias of :func:`torch.det`. """) + +norm = _add_docstr(_linalg.linalg_norm, r""" +linalg.norm(input, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor + +Returns the matrix norm or vector norm of a given tensor. + +This function can calculate one of eight different types of matrix norms, or one +of an infinite number of vector norms, depending on both the number of reduction +dimensions and the value of the `ord` parameter. + +Args: + input (Tensor): The input tensor. If dim is None, x must be 1-D or 2-D, unless :attr:`ord` + is None. If both :attr:`dim` and :attr:`ord` are None, the 2-norm of the input flattened to 1-D + will be returned. + + ord (int, float, inf, -inf, 'fro', 'nuc', optional): The order of norm. + inf refers to :attr:`float('inf')`, numpy's :attr:`inf` object, or any equivalent object. + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- not supported -- + 'nuc' nuclear norm -- not supported -- + inf max(sum(abs(x), dim=1)) max(abs(x)) + -inf min(sum(abs(x), dim=1)) min(abs(x)) + 0 -- not supported -- sum(x != 0) + 1 max(sum(abs(x), dim=0)) as below + -1 min(sum(abs(x), dim=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- not supported -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Default: ``None`` + + dim (int, 2-tuple of ints, 2-list of ints, optional): If :attr:`dim` is an int, + vector norm will be calculated over the specified dimension. If :attr:`dim` + is a 2-tuple of ints, matrix norm will be calculated over the specified + dimensions. If :attr:`dim` is None, matrix norm will be calculated + when the input tensor has two dimensions, and vector norm will be + calculated when the input tensor has one dimension. Default: ``None`` + + keepdim (bool, optional): If set to True, the reduced dimensions are retained + in the result as dimensions with size one. Default: ``False`` + +Keyword args: + + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + + dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to + :attr:`dtype` before performing the operation, and the returned tensor's type + will be :attr:`dtype`. If this argument is used in conjunction with the + :attr:`out` argument, the output tensor's type must match this argument or a + RuntimeError will be raised. This argument is not currently supported for + :attr:`ord='nuc'` or :attr:`ord='fro'`. Default: ``None`` + +Examples:: + + >>> import torch + >>> from torch import linalg as LA + >>> a = torch.arange(9, dtype=torch.float) - 4 + >>> a + tensor([-4., -3., -2., -1., 0., 1., 2., 3., 4.]) + >>> b = a.reshape((3, 3)) + >>> b + tensor([[-4., -3., -2.], + [-1., 0., 1.], + [ 2., 3., 4.]]) + + >>> LA.norm(a) + tensor(7.7460) + >>> LA.norm(b) + tensor(7.7460) + >>> LA.norm(b, 'fro') + tensor(7.7460) + >>> LA.norm(a, float('inf')) + tensor(4.) + >>> LA.norm(b, float('inf')) + tensor(9.) + >>> LA.norm(a, -float('inf')) + tensor(0.) + >>> LA.norm(b, -float('inf')) + tensor(2.) + + >>> LA.norm(a, 1) + tensor(20.) + >>> LA.norm(b, 1) + tensor(7.) + >>> LA.norm(a, -1) + tensor(0.) + >>> LA.norm(b, -1) + tensor(6.) + >>> LA.norm(a, 2) + tensor(7.7460) + >>> LA.norm(b, 2) + tensor(7.3485) + + >>> LA.norm(a, -2) + tensor(0.) + >>> LA.norm(b.double(), -2) + tensor(1.8570e-16, dtype=torch.float64) + >>> LA.norm(a, 3) + tensor(5.8480) + >>> LA.norm(a, -3) + tensor(0.) + +Using the :attr:`dim` argument to compute vector norms:: + + >>> c = torch.tensor([[1., 2., 3.], + ... [-1, 1, 4]]) + >>> LA.norm(c, dim=0) + tensor([1.4142, 2.2361, 5.0000]) + >>> LA.norm(c, dim=1) + tensor([3.7417, 4.2426]) + >>> LA.norm(c, ord=1, dim=1) + tensor([6., 6.]) + +Using the :attr:`dim` argument to compute matrix norms:: + + >>> m = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) + >>> LA.norm(m, dim=(1,2)) + tensor([ 3.7417, 11.2250]) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (tensor(3.7417), tensor(11.2250)) +""") diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 4ec89e4c9b0b855..0b8515d7f0163b2 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -230,6 +230,72 @@ def export(model, args, f, export_params=True, verbose=False, training=TrainingM custom_opsets, enable_onnx_checker, use_external_data_format) +def _diagnose_export(*args, **kwargs): + r""" + This diagnostic tool runs your model with operator_export_type set to + OperatorExportTypes.ONNX_FALLTHROUGH once in order to get a list of + all the ops that are not supported/implemented by the current exporter + + Arguments: + model (torch.nn.Module): the model to be exported. + args (tuple of arguments or torch.Tensor): the inputs to + the model, e.g., such that ``model(*args)`` is a valid + invocation of the model. Any non-Tensor arguments will + be hard-coded into the exported model; any Tensor arguments + will become inputs of the exported model, in the order they + occur in args. If args is a Tensor, this is equivalent + to having called it with a 1-ary tuple of that Tensor. + (Note: passing keyword arguments to the model is not currently + supported. Give us a shout if you need it.) + f: a file-like object (has to implement fileno that returns a file descriptor) + or a string containing a file name. A binary Protobuf will be written + to this file. + input_names(list of strings, default empty list): names to assign to the + input nodes of the graph, in order + output_names(list of strings, default empty list): names to assign to the + output nodes of the graph, in order + opset_version (int, default is 9): by default we export the model to the + opset version of the onnx submodule. Since ONNX's latest opset may + evolve before next stable release, by default we export to one stable + opset version. Right now, supported stable opset version is 9. + The opset_version must be _onnx_master_opset or in _onnx_stable_opsets + which are defined in torch/onnx/symbolic_helper.py + dynamic_axes (dict> or dict, default empty dict): + a dictionary to specify dynamic axes of input/output, such that: + - KEY: input and/or output names + - VALUE: index of dynamic axes for given key and potentially the name to be used for + exported dynamic axes. Similar behavior to dyanmic axes argument in export + + operator_export_type is set to OperatorExportTypes.ONNX_FALLTHROUGH by default + OperatorExportTypes.ONNX_FALLTHROUGH: If an op is not supported + in ONNX, fall through and export the operator as is, as a custom + ONNX op. Using this mode, the op can be exported and implemented by + the user for their runtime backend. + Example graph:: + + graph(%0 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu)): + %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + %4 : None = prim::Constant() + %5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 + return (%5) + + is exported as:: + + graph(%0 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu)): + %6 : Long(requires_grad=0, device=cpu) = prim::Constant[value={0}]() + %4 : None = prim::Constant() + %5 : Float(2:12, 3:4, 4:1, requires_grad=0, device=cpu) = aten::cumsum(%0, %6, %4) # main.py:6:0 + return (%5) + + In the above example, aten::add with alpha != 1 is not supported and aten::cumsum in not + implemented in opset 9, hence exporter falls through and provides a list of unsupported ops, + the result being: + Unsupported ops : [aten:add, aten:cumsum] + """ + from torch.onnx import utils + result = utils._diagnose_export(*args, **kwargs) + return result + def export_to_pretty_string(*args, **kwargs): from torch.onnx import utils return utils.export_to_pretty_string(*args, **kwargs) diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 4ead70c3b394d33..03fbc40546abc33 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -122,11 +122,15 @@ def parse_args(*arg_descriptors): def decorator(fn): fn._arg_descriptors = arg_descriptors - def wrapper(g, *args): + def wrapper(g, *args, **kwargs): # some args may be optional, so the length may be smaller assert len(arg_descriptors) >= len(args) args = [_parse_arg(arg, arg_desc) for arg, arg_desc in zip(args, arg_descriptors)] - return fn(g, *args) + # only support _outputs in kwargs + assert len(kwargs) <= 1 + if len(kwargs) == 1: + assert '_outputs' in kwargs + return fn(g, *args, **kwargs) # In Python 2 functools.wraps chokes on partially applied functions, so we need this as a workaround try: wrapper = wraps(fn)(wrapper) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 9235b0216a510f5..b9ea02b31b11dcd 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -227,10 +227,15 @@ def scatter(g, self, dim, index, src): from torch.onnx.symbolic_opset9 import expand_as if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: return g.op("ATen", self, dim, index, src, operator_s="scatter") + src_type = src.type().scalarType() src = sym_help._maybe_get_scalar(src) if sym_help._is_value(src): return g.op("ScatterElements", self, index, src, axis_i=dim) else: + # Check if scalar 'src' has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if self.type().scalarType() != src_type: + src = g.op("Cast", src, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) return g.op("ScatterElements", self, index, expand_as(g, src, index), axis_i=dim) diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 4e065247869dd33..6abbde0bcf51330 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -73,22 +73,22 @@ def argmax(g, input, dim, keepdim): if sym_help._is_none(dim): from torch.onnx.symbolic_opset9 import reshape flattened = reshape(g, input, (-1,)) - return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False, select_last_index_i=True) + return g.op('ArgMax', flattened, axis_i=0, keepdims_i=False, select_last_index_i=False) else: dim = _parse_arg(dim, 'i') keepdim = _parse_arg(keepdim, 'i') - return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=True) + return g.op('ArgMax', input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=False) def argmin(g, input, dim, keepdim): if sym_help._is_none(dim): from torch.onnx.symbolic_opset9 import reshape flattened = reshape(g, input, (-1,)) - return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False, select_last_index_i=True) + return g.op('ArgMin', flattened, axis_i=0, keepdims_i=False, select_last_index_i=False) else: dim = _parse_arg(dim, 'i') keepdim = _parse_arg(keepdim, 'i') - return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=True) + return g.op('ArgMin', input, axis_i=dim, keepdims_i=keepdim, select_last_index_i=False) def pow(g, self, exponent): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index a52419269727da7..1ec83634332e809 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -997,6 +997,10 @@ def gt(g, input, other): def gt_impl(g, input, other): + if input.type().scalarType() is not None and input.type().scalarType() == 'Bool' and \ + other.type().scalarType() is not None and other.type().scalarType() == 'Bool': + input = g.op("Cast", input, to_i=sym_help.cast_pytorch_to_onnx['Int']) + other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Int']) return g.op("Greater", input, other) @@ -1005,6 +1009,10 @@ def lt(g, input, other): def lt_impl(g, input, other): + if input.type().scalarType() is not None and input.type().scalarType() == 'Bool' and \ + other.type().scalarType() is not None and other.type().scalarType() == 'Bool': + input = g.op("Cast", input, to_i=sym_help.cast_pytorch_to_onnx['Int']) + other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Int']) return g.op("Less", input, other) @@ -1060,10 +1068,14 @@ def __lshift_(g, self, other): return lshift -def where(g, condition, self, other): +@parse_args('v', 'v', 'v', 'i') +def where(g, condition, self=None, other=None, _outputs=None): # Assumes that torch.where's first argument takes only Bool and Byte tensors. - if condition.type().scalarType() != 'Bool': + if condition.type().scalarType() != 'Bool': condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx['Bool']) + if self is None: + condition = torch.onnx.symbolic_opset9.nonzero(g, condition) + return unbind(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs) return g.op("Where", condition, self, other) @@ -2115,10 +2127,15 @@ def argmin(g, input, dim, keepdim): @parse_args('v', 'i', 'v', 'v') def scatter(g, self, dim, index, src): + src_type = src.type().scalarType() src = sym_help._maybe_get_scalar(src) if sym_help._is_value(src): return g.op("Scatter", self, index, src, axis_i=dim) else: + # Check if scalar 'src' has same type as self (PyTorch allows different + # type for scalar src (but not when src is tensor)). If not, insert Cast node. + if self.type().scalarType() != src_type: + src = g.op("Cast", src, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()]) return g.op("Scatter", self, index, expand_as(g, src, index), axis_i=dim) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f51adfe1b128e59..77d179306223f72 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -492,6 +492,24 @@ def _export_to_pretty_string(model, args, f, export_params=True, verbose=False, operator_export_type, google_printer, val_keep_init_as_ip, custom_opsets, val_add_node_names) +def _diagnose_export(model, args, f, verbose=False, training=TrainingMode.EVAL, + input_names=None, output_names=None, opset_version=None, dynamic_axes=None): + from torch.onnx.symbolic_helper import _default_onnx_opset_version, _set_opset_version + if opset_version is None: + opset_version = _default_onnx_opset_version + _set_opset_version(opset_version) + # operator_export_type is set ro ONNX_FALLTHROUGH by default so that if an op is not supported + # in ONNX, fall through will occur and export the operator as is, as a custom ONNX op. + operator_export_type = OperatorExportTypes.ONNX_FALLTHROUGH + with select_model_mode_for_export(model, training): + graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names, + output_names, operator_export_type) + # The output 'unsupported_ops' will contain the names of all the ops that are not supported in ONNX + unsupported_ops = list() + for node in graph.nodes(): + if node.kind().split(':')[0] not in ['onnx', 'prim']: + unsupported_ops.append(node.kind()) + return graph, unsupported_ops # NOTE: the output `torch_out` will contain the output tensors resulting from # the trace of a Module. In the case that a torch.nn.ScriptModule is passed in, diff --git a/torch/overrides.py b/torch/overrides.py index 168652f8a898f61..6532fd4b257d60a 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -229,10 +229,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.asin: lambda input, out=None: -1, torch.arcsin: lambda input, out=None: -1, torch.asinh: lambda input, out=None: -1, + torch.arcsinh: lambda input, out=None: -1, torch.atan: lambda input, out=None: -1, torch.arctan: lambda input, out=None: -1, torch.atan2: lambda input, other, out=None: -1, torch.atanh: lambda input, out=None: -1, + torch.arctanh: lambda input, out=None: -1, torch.atleast_1d: lambda input: -1, torch.atleast_2d: lambda input: -1, torch.atleast_3d: lambda input: -1, @@ -672,6 +674,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.roll: lambda input, shifts, dims=None: -1, torch.rot90: lambda input, k=1, dims=(0, 1): -1, torch.round: lambda input, out=None: -1, + torch.rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1), torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1, torch.rsqrt: lambda input, out=None: -1, torch.rsub: lambda input, other, alpha=1: -1, @@ -704,6 +707,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True: -1), torch.sub: lambda input, other, out=None: -1, + torch.subtract: lambda input, other, out=None: -1, torch.sum: lambda input, dim=None: -1, torch.nansum: lambda input, dim=None: -1, torch.svd: lambda input, some=True, compute_uv=True, out=None: -1, diff --git a/torch/quantization/fuse_modules.py b/torch/quantization/fuse_modules.py index 3902315b9f3b71d..56aabdc1b37918a 100644 --- a/torch/quantization/fuse_modules.py +++ b/torch/quantization/fuse_modules.py @@ -6,6 +6,8 @@ import torch.nn as nn import torch.nn.intrinsic as nni +from typing import Type, List, Optional, Union, Callable, Tuple, Dict + def fuse_conv_bn(conv, bn): r"""Given the conv and bn modules, fuses them and returns the fused module @@ -48,6 +50,7 @@ def fuse_conv_bn_relu(conv, bn, relu): """ assert(conv.training == bn.training == relu.training),\ "Conv and BN both must be in the same mode (train or eval)." + fused_module : Optional[Type[nn.Sequential]] = None if conv.training: map_to_fused_module_train = { nn.Conv2d: torch_fused.ConvBnReLU2d, @@ -69,11 +72,12 @@ def fuse_conv_bn_relu(conv, bn, relu): } fused_module = map_to_fused_module_eval[type(conv)] if fused_module is not None: - return fused_module(nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu) + fused_conv = nn.utils.fusion.fuse_conv_bn_eval(conv, bn) + return fused_module(fused_conv, relu) else: raise NotImplementedError("Cannot fuse eval modules: {}".format((conv, bn, relu))) -OP_LIST_TO_FUSER_METHOD = { +OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = { (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn, (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu, (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn, @@ -119,24 +123,26 @@ def fuse_known_modules(mod_list): the fused operation. The rest of the elements are set to nn.Identity() """ types = tuple(type(m) for m in mod_list) - fuser_method = OP_LIST_TO_FUSER_METHOD.get(types, None) + fuser_method = OP_LIST_TO_FUSER_METHOD.get(types) if fuser_method is None: raise NotImplementedError("Cannot fuse modules: {}".format(types)) - new_mod = [None] * len(mod_list) - new_mod[0] = fuser_method(*mod_list) + new_mod : List[Optional[nn.Module]] = [None] * len(mod_list) + fused = fuser_method(*mod_list) # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion # Move pre forward hooks of the base module to resulting fused module for handle_id, pre_hook_fn in mod_list[0]._forward_pre_hooks.items(): - new_mod[0].register_forward_pre_hook(pre_hook_fn) + fused.register_forward_pre_hook(pre_hook_fn) del mod_list[0]._forward_pre_hooks[handle_id] # Move post forward hooks of the last module to resulting fused module for handle_id, hook_fn in mod_list[-1]._forward_hooks.items(): - new_mod[0].register_forward_hook(hook_fn) + fused.register_forward_hook(hook_fn) del mod_list[-1]._forward_hooks[handle_id] + new_mod[0] = fused for i in range(1, len(mod_list)): - new_mod[i] = nn.Identity() - new_mod[i].training = mod_list[0].training + identity = nn.Identity() + identity.training = mod_list[0].training + new_mod[i] = identity return new_mod diff --git a/torch/quantization/fx/fuse.py b/torch/quantization/fx/fuse.py index 4376548b3945474..9f187f1a2fd32e8 100644 --- a/torch/quantization/fx/fuse.py +++ b/torch/quantization/fx/fuse.py @@ -61,7 +61,6 @@ def fuse(self, quantizer, load_arg): op_list.append(relu) relu.training = self.conv.training if self.bn_node is not None: - setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn_relu(self.conv, self.bn, relu)) op_list.append(self.bn) op_list.append(self.conv) else: @@ -69,6 +68,8 @@ def fuse(self, quantizer, load_arg): op_list.append(self.bn) op_list.append(self.conv) + # the modules are added in order of relu - bn - conv + # so we need to correct it op_list.reverse() op_type_list = tuple(type(m) for m in op_list) conv_parent_name, conv_name = _parent_name(self.conv_node.target) @@ -131,8 +132,8 @@ def fuse_conv_bn(self, model, inplace=False): self.modules = dict(input_root.named_modules()) fusion_patterns = get_fusion_patterns() - # find conv-bn pairs - conv_bn_pairs = self._find_matches(input_root, input_graph, fusion_patterns) + # find fusion + fusion_pairs = self._find_matches(input_root, input_graph, fusion_patterns) self.fused_graph = Graph() env = {} @@ -140,7 +141,7 @@ def load_arg(a): return map_arg(a, lambda node: env[node.name]) for node in input_graph.nodes: - root_node, obj = conv_bn_pairs.get(node.name, (None, None)) + root_node, obj = fusion_pairs.get(node.name, (None, None)) if root_node is node: env[node.name] = obj.fuse(self, load_arg) elif root_node is None: @@ -161,7 +162,9 @@ def apply_match(pattern, node, match): for subpattern, arg in zip(args, node.args): apply_match(subpattern, arg, match) else: - match_map[node.name] = match + # the first pattern matches will take precedence + if node.name not in match_map: + match_map[node.name] = match for node in reversed(graph.nodes): if node.name not in match_map: diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 447257e8f9d9b34..4253059e9437c43 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -884,7 +884,7 @@ def get_next_i(module, qparams): qparam_full_path = key + str(i) if parent_name: qparam_full_path = parent_name + '.' + qparam_full_path - inputs.append(self.quantized_graph.get_param(qparam_full_path)) + inputs.append(self.quantized_graph.create_node('get_param', qparam_full_path)) quant_env[node.name] = self.quantized_graph.create_node('call_function', torch.quantize_per_tensor, inputs, {}) continue # dequantize inputs for the node that are not quantized diff --git a/torch/quantization/quantize.py b/torch/quantization/quantize.py index dc60d2580a3414d..f4c45c86366eb79 100644 --- a/torch/quantization/quantize.py +++ b/torch/quantization/quantize.py @@ -261,7 +261,7 @@ def quantize_dynamic(model, qconfig_spec=None, dtype=torch.qint8, If `qconfig` is provided, the `dtype` argument is ignored. Args: - module: input model + model: input model qconfig_spec: Either: - A dictionary that maps from name or type of submodule to quantization diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index 40108dd3725a72d..396d0718efbcf6a 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -313,20 +313,25 @@ def all_types_and_complex_and(*dtypes): def all_types_and_half(): return _all_types_and_half -def get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=True, include_complex=True) -> List[torch.dtype]: +def get_all_dtypes(include_half=True, + include_bfloat16=True, + include_bool=True, + include_complex=True, + include_complex32=False + ) -> List[torch.dtype]: dtypes = get_all_int_dtypes() + get_all_fp_dtypes(include_half=include_half, include_bfloat16=include_bfloat16) if include_bool: dtypes.append(torch.bool) if include_complex: - dtypes += get_all_complex_dtypes() + dtypes += get_all_complex_dtypes(include_complex32) return dtypes def get_all_math_dtypes(device) -> List[torch.dtype]: return get_all_int_dtypes() + get_all_fp_dtypes(include_half=device.startswith('cuda'), include_bfloat16=False) + get_all_complex_dtypes() -def get_all_complex_dtypes() -> List[torch.dtype]: - return [torch.complex64, torch.complex128] +def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]: + return [torch.complex32, torch.complex64, torch.complex128] if include_complex32 else [torch.complex64, torch.complex128] def get_all_int_dtypes() -> List[torch.dtype]: diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 002e5e51d054e47..d8b402ae4ad7106 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -7,13 +7,12 @@ checking quantization api and properties of resulting modules. """ -import copy -import io -import functools import torch import torch.nn as nn import torch.nn.quantized as nnq import torch.nn.quantized.dynamic as nnqd +import torch.distributed as dist + from torch.testing._internal.common_utils import TestCase from torch.quantization import QuantWrapper, QuantStub, DeQuantStub, \ default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ @@ -33,6 +32,12 @@ fuse, ) +import copy +import io +import functools +import time +import os + import unittest import numpy as np from torch.testing import FileCheck @@ -101,6 +106,85 @@ def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): correct += (predicted == target).sum().item() return train_loss, correct, total +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + +def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches): + model.train() + cnt = 0 + for image, target in data_loader: + start_time = time.time() + print('.', end='') + cnt += 1 + image, target = image.to(device), target.to(device) + output = model(image) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + if cnt >= ntrain_batches: + return + return + +def ddp_setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + +def ddp_cleanup(): + dist.destroy_process_group() + +def run_ddp(rank, world_size, prepared): + ddp_setup(rank, world_size) + prepared.cuda() + prepared = torch.nn.parallel.DistributedDataParallel(prepared, device_ids=[rank]) + prepared.to(rank) + model_with_ddp = prepared + optimizer = torch.optim.SGD(model_with_ddp.parameters(), lr=0.0001) + train_one_epoch(model_with_ddp, criterion, optimizer, dataset, rank, 1) + ddp_cleanup() + + def convert_dynamic(module): convert(module, DEFAULT_DYNAMIC_MODULE_MAPPING, inplace=True) @@ -176,6 +260,13 @@ def wrapper(*args, **kwargs): fn(*args, **kwargs) return wrapper +try: + import torchvision # noqa: F401 + HAS_TORCHVISION = True +except ImportError: + HAS_TORCHVISION = False +skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision") + def get_script_module(model, tracing, data): return torch.jit.trace(model, data) if tracing else torch.jit.script(model) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index d1fbd59e362d052..36ba1db92d38503 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -389,7 +389,7 @@ def wrapper(*args, **kwargs): # This decorator can be used for API tests that call torch.set_deterministic(). # When the test is finished, it will restore the previous deterministic flag -# setting. Also, if CUDA >= 10.2, this will set the environment variable +# setting. Also, if CUDA >= 10.2, this will set the environment variable # CUBLAS_WORKSPACE_CONFIG=:4096:8 so that the error associated with that setting # is not thrown during the test unless the test changes that variable on purpose. # The previous CUBLAS_WORKSPACE_CONFIG setting will also be restored once the diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index d2e783a3abd3bb5..542c182b036473d 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -675,3 +675,15 @@ def get_module_method(m, module, method): def attrs_with_prefix(module, prefix): return [x for x, _ in module._modules._c.items() if x.startswith(prefix)] + +def warmup_backward(f, *args): + profiling_count = 2 + results = [] + for i in range(profiling_count): + if len(args) > 0: + r = torch.autograd.grad(f, *args) + results.append(r) + else: + f.backward(retain_graph=True) + + return results